diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 62c1412486..374bd7ed4d 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -42,6 +42,9 @@ jobs:
if: steps.cache.outputs.cache-hit != 'true'
run: pip install -e .[docs]
+ - name: Check no warnings
+ run: mkdocs build --strict
+
- name: Set git credentials
run: |
git config --global user.name "${{ github.actor }}"
diff --git a/README.md b/README.md
index 7dff701ebf..728d69c0b4 100644
--- a/README.md
+++ b/README.md
@@ -78,6 +78,8 @@ Requires Python 3.9+
In addition, the following extras are available:
+### LLMs
+
- `anthropic`: for using models available in [Anthropic API](https://www.anthropic.com/api) via the `AnthropicLLM` integration.
- `cohere`: for using models available in [Cohere](https://cohere.ai/) via the `CohereLLM` integration.
- `argilla`: for exporting the generated datasets to [Argilla](https://argilla.io/).
@@ -91,19 +93,32 @@ In addition, the following extras are available:
- `openai`: for using [OpenAI API](https://openai.com/blog/openai-api) models via the `OpenAILLM` integration, or the rest of the integrations based on OpenAI and relying on its client as `AnyscaleLLM`, `AzureOpenAILLM`, and `TogetherLLM`.
- `vertexai`: for using [Google Vertex AI](https://cloud.google.com/vertex-ai) proprietary models via the `VertexAILLM` integration.
- `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration.
+- `sentence-transformers`: for generating sentence embeddings using [sentence-transformers](https://github.com/UKPLab/sentence-transformers).
+
+### Structured generation
+
+- `outlines`: for using structured generation of LLMs with [outlines](https://github.com/outlines-dev/outlines).
+- `instructor`: for using structured generation of LLMs with [Instructor](https://github.com/jxnl/instructor/).
+
+### Data processing
+
+- `ray`: for scaling and distributing a pipeline with [Ray](https://github.com/ray-project/ray).
+- `faiss-cpu` and `faiss-gpu`: for generating sentence embeddings using [faiss](https://github.com/facebookresearch/faiss).
+- `text-clustering`: for using text clustering with [UMAP](https://github.com/lmcinnes/umap) and [Scikit-learn](https://github.com/scikit-learn/scikit-learn).
+- `minhash`: for using minhash for duplicate detection with [datasketch](https://github.com/datasketch/datasketch) and [nltk](https://github.com/nltk/nltk).
### Example
-To run the following example you must install `distilabel` with both `openai` extra:
+To run the following example you must install `distilabel` with the `hf-inference-endpoints` extra:
```sh
-pip install "distilabel[openai]" --upgrade
+pip install "distilabel[hf-inference-endpoints]" --upgrade
```
Then run:
```python
-from distilabel.llms import OpenAILLM
+from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
@@ -114,9 +129,14 @@ with Pipeline(
) as pipeline:
load_dataset = LoadDataFromHub(output_mappings={"prompt": "instruction"})
- generate_with_openai = TextGeneration(llm=OpenAILLM(model="gpt-3.5-turbo"))
+ text_generation = TextGeneration(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+ ),
+ )
- load_dataset >> generate_with_openai
+ load_dataset >> text_generation
if __name__ == "__main__":
distiset = pipeline.run(
@@ -125,7 +145,7 @@ if __name__ == "__main__":
"repo_id": "distilabel-internal-testing/instruction-dataset-mini",
"split": "test",
},
- generate_with_openai.name: {
+ text_generation.name: {
"llm": {
"generation_kwargs": {
"temperature": 0.7,
@@ -135,6 +155,7 @@ if __name__ == "__main__":
},
},
)
+ distiset.push_to_hub(repo_id="distilabel-example")
```
## Badges
diff --git a/docs/api/embedding/embedding_gallery.md b/docs/api/embedding/embedding_gallery.md
new file mode 100644
index 0000000000..3eed3ab50e
--- /dev/null
+++ b/docs/api/embedding/embedding_gallery.md
@@ -0,0 +1,8 @@
+# Embedding Gallery
+
+This section contains the existing [`Embeddings`][distilabel.embeddings] subclasses implemented in `distilabel`.
+
+::: distilabel.embeddings
+ options:
+ filters:
+ - "!^Embeddings$"
\ No newline at end of file
diff --git a/docs/api/embedding/index.md b/docs/api/embedding/index.md
new file mode 100644
index 0000000000..675593e183
--- /dev/null
+++ b/docs/api/embedding/index.md
@@ -0,0 +1,7 @@
+# Embedding
+
+This section contains the API reference for the `distilabel` embeddings.
+
+For more information on how the [`Embeddings`][distilabel.steps.tasks.Task] works and see some examples.
+
+::: distilabel.embeddings.base
\ No newline at end of file
diff --git a/docs/api/errors.md b/docs/api/errors.md
new file mode 100644
index 0000000000..9ba2166302
--- /dev/null
+++ b/docs/api/errors.md
@@ -0,0 +1,8 @@
+# Errors
+
+This section contains the `distilabel` custom errors. Unlike [exceptions](exceptions.md), errors in `distilabel` are used to handle unexpected situations that can't be anticipated and that can't be handled in a controlled way.
+
+:::distilabel.errors.DistilabelError
+:::distilabel.errors.DistilabelUserError
+:::distilabel.errors.DistilabelTypeError
+:::distilabel.errors.DistilabelNotImplementedError
diff --git a/docs/api/exceptions.md b/docs/api/exceptions.md
new file mode 100644
index 0000000000..5826756a22
--- /dev/null
+++ b/docs/api/exceptions.md
@@ -0,0 +1,7 @@
+# Exceptions
+
+This section contains the `distilabel` custom exceptions. Unlike [errors](errors.md), exceptions in `distilabel` are used to handle specific situations that can be anticipated and that can be handled in a controlled way internally by the library.
+
+:::distilabel.exceptions.DistilabelException
+:::distilabel.exceptions.DistilabelGenerationException
+:::distilabel.exceptions.DistilabelOfflineBatchGenerationNotFinishedException
diff --git a/docs/api/llm/anthropic.md b/docs/api/llm/anthropic.md
deleted file mode 100644
index 400571c6ed..0000000000
--- a/docs/api/llm/anthropic.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# AnthropicLLM
-
-::: distilabel.llms.anthropic
diff --git a/docs/api/llm/anyscale.md b/docs/api/llm/anyscale.md
deleted file mode 100644
index 90aa0cd6ea..0000000000
--- a/docs/api/llm/anyscale.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# AnyscaleLLM
-
-::: distilabel.llms.anyscale
diff --git a/docs/api/llm/azure.md b/docs/api/llm/azure.md
deleted file mode 100644
index faa127d5bc..0000000000
--- a/docs/api/llm/azure.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# AzureOpenAILLM
-
-::: distilabel.llms.azure
diff --git a/docs/api/llm/cohere.md b/docs/api/llm/cohere.md
deleted file mode 100644
index c7064b7a75..0000000000
--- a/docs/api/llm/cohere.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# CohereLLM
-
-::: distilabel.llms.cohere
diff --git a/docs/api/llm/groq.md b/docs/api/llm/groq.md
deleted file mode 100644
index 0a5264a772..0000000000
--- a/docs/api/llm/groq.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# GroqLLM
-
-::: distilabel.llms.groq
diff --git a/docs/api/llm/huggingface.md b/docs/api/llm/huggingface.md
deleted file mode 100644
index 30920255fe..0000000000
--- a/docs/api/llm/huggingface.md
+++ /dev/null
@@ -1,6 +0,0 @@
-# Hugging Face
-
-This section contains the reference for Hugging Face integrations:
-
-::: distilabel.llms.huggingface.inference_endpoints
-::: distilabel.llms.huggingface.transformers
diff --git a/docs/api/llm/litellm.md b/docs/api/llm/litellm.md
deleted file mode 100644
index 90a4d2d631..0000000000
--- a/docs/api/llm/litellm.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# LiteLLM
-
-::: distilabel.llms.litellm
diff --git a/docs/api/llm/llamacpp.md b/docs/api/llm/llamacpp.md
deleted file mode 100644
index 02598c1a64..0000000000
--- a/docs/api/llm/llamacpp.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# LlamaCppLLM
-
-::: distilabel.llms.llamacpp
diff --git a/docs/api/llm/llm_gallery.md b/docs/api/llm/llm_gallery.md
new file mode 100644
index 0000000000..ad0b1b75f0
--- /dev/null
+++ b/docs/api/llm/llm_gallery.md
@@ -0,0 +1,10 @@
+# LLM Gallery
+
+This section contains the existing [`LLM`][distilabel.llms] subclasses implemented in `distilabel`.
+
+::: distilabel.llms
+ options:
+ filters:
+ - "!^LLM$"
+ - "!^AsyncLLM$"
+ - "!typing"
\ No newline at end of file
diff --git a/docs/api/llm/mistral.md b/docs/api/llm/mistral.md
deleted file mode 100644
index 069488eadd..0000000000
--- a/docs/api/llm/mistral.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# MistralLLM
-
-::: distilabel.llms.mistral
diff --git a/docs/api/llm/ollama.md b/docs/api/llm/ollama.md
deleted file mode 100644
index 25e4b662a1..0000000000
--- a/docs/api/llm/ollama.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# OllamaLLM
-
-::: distilabel.llms.ollama
diff --git a/docs/api/llm/openai.md b/docs/api/llm/openai.md
deleted file mode 100644
index 381306ad59..0000000000
--- a/docs/api/llm/openai.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# OpenAILLM
-
-::: distilabel.llms.openai
diff --git a/docs/api/llm/together.md b/docs/api/llm/together.md
deleted file mode 100644
index 6530165203..0000000000
--- a/docs/api/llm/together.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# TogetherLLM
-
-::: distilabel.llms.together
diff --git a/docs/api/llm/vertexai.md b/docs/api/llm/vertexai.md
deleted file mode 100644
index f8990605d8..0000000000
--- a/docs/api/llm/vertexai.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# VertexAILLM
-
-::: distilabel.llms.vertexai
diff --git a/docs/api/llm/vllm.md b/docs/api/llm/vllm.md
deleted file mode 100644
index 053b8535bb..0000000000
--- a/docs/api/llm/vllm.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# vLLM
-
-::: distilabel.llms.vllm
diff --git a/docs/api/pipeline/step_wrapper.md b/docs/api/pipeline/step_wrapper.md
new file mode 100644
index 0000000000..e68b64d1d9
--- /dev/null
+++ b/docs/api/pipeline/step_wrapper.md
@@ -0,0 +1,4 @@
+# Step Wrapper
+
+::: distilabel.pipeline.step_wrapper._StepWrapper
+::: distilabel.pipeline.step_wrapper._StepWrapperException
diff --git a/docs/api/pipeline/utils.md b/docs/api/pipeline/utils.md
deleted file mode 100644
index c8ad6f2e54..0000000000
--- a/docs/api/pipeline/utils.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Pipeline Utils
-
-::: distilabel.pipeline.utils
diff --git a/docs/api/step/typing.md b/docs/api/step/typing.md
new file mode 100644
index 0000000000..1a86e7dac1
--- /dev/null
+++ b/docs/api/step/typing.md
@@ -0,0 +1,3 @@
+# Step Typing
+
+::: distilabel.steps.typing
\ No newline at end of file
diff --git a/docs/api/step_gallery/columns.md b/docs/api/step_gallery/columns.md
index 9e5392d850..7b75053e6a 100644
--- a/docs/api/step_gallery/columns.md
+++ b/docs/api/step_gallery/columns.md
@@ -6,3 +6,4 @@ This section contains the existing steps intended to be used for common column o
::: distilabel.steps.columns.keep
::: distilabel.steps.columns.merge
::: distilabel.steps.columns.group
+::: distilabel.steps.columns.utils
diff --git a/docs/api/step_gallery/extra.md b/docs/api/step_gallery/extra.md
index e310e45d4b..3d3e6f9c57 100644
--- a/docs/api/step_gallery/extra.md
+++ b/docs/api/step_gallery/extra.md
@@ -1,6 +1,11 @@
# Extra
-::: distilabel.steps.generators.data
-::: distilabel.steps.deita
-::: distilabel.steps.formatting
-::: distilabel.steps.typing
+::: distilabel.steps
+ options:
+ filters:
+ - "!Argilla"
+ - "!Columns"
+ - "!From(Disk|FileSystem)"
+ - "!Hub"
+ - "![Ss]tep"
+ - "!typing"
diff --git a/docs/api/step_gallery/hugging_face.md b/docs/api/step_gallery/hugging_face.md
index 42fb85e795..c801aca86b 100644
--- a/docs/api/step_gallery/hugging_face.md
+++ b/docs/api/step_gallery/hugging_face.md
@@ -5,3 +5,4 @@ This section contains the existing steps integrated with `Hugging Face` so as to
::: distilabel.steps.LoadDataFromDisk
::: distilabel.steps.LoadDataFromFileSystem
::: distilabel.steps.LoadDataFromHub
+::: distilabel.steps.PushToHub
\ No newline at end of file
diff --git a/docs/api/task_gallery/index.md b/docs/api/task/task_gallery.md
similarity index 100%
rename from docs/api/task_gallery/index.md
rename to docs/api/task/task_gallery.md
diff --git a/docs/assets/images/sections/caching/caching_1.png b/docs/assets/images/sections/caching/caching_1.png
new file mode 100644
index 0000000000..cde228769b
Binary files /dev/null and b/docs/assets/images/sections/caching/caching_1.png differ
diff --git a/docs/assets/images/sections/caching/caching_2.png b/docs/assets/images/sections/caching/caching_2.png
new file mode 100644
index 0000000000..8f0d9d4d51
Binary files /dev/null and b/docs/assets/images/sections/caching/caching_2.png differ
diff --git a/docs/assets/images/sections/caching/caching_pipe_1.png b/docs/assets/images/sections/caching/caching_pipe_1.png
deleted file mode 100644
index f41f38a601..0000000000
Binary files a/docs/assets/images/sections/caching/caching_pipe_1.png and /dev/null differ
diff --git a/docs/assets/images/sections/caching/caching_pipe_2.png b/docs/assets/images/sections/caching/caching_pipe_2.png
deleted file mode 100644
index 22adebc1ad..0000000000
Binary files a/docs/assets/images/sections/caching/caching_pipe_2.png and /dev/null differ
diff --git a/docs/assets/images/sections/caching/caching_pipe_3.png b/docs/assets/images/sections/caching/caching_pipe_3.png
deleted file mode 100644
index b41a3a6c8b..0000000000
Binary files a/docs/assets/images/sections/caching/caching_pipe_3.png and /dev/null differ
diff --git a/docs/assets/images/sections/caching/caching_pipe_4.png b/docs/assets/images/sections/caching/caching_pipe_4.png
deleted file mode 100644
index 12ea2c7f2c..0000000000
Binary files a/docs/assets/images/sections/caching/caching_pipe_4.png and /dev/null differ
diff --git a/docs/assets/images/sections/community/compare-pull-request.PNG b/docs/assets/images/sections/community/compare-pull-request.PNG
new file mode 100644
index 0000000000..ace5c010b9
Binary files /dev/null and b/docs/assets/images/sections/community/compare-pull-request.PNG differ
diff --git a/docs/assets/images/sections/community/create-branch.PNG b/docs/assets/images/sections/community/create-branch.PNG
new file mode 100644
index 0000000000..24dfc19755
Binary files /dev/null and b/docs/assets/images/sections/community/create-branch.PNG differ
diff --git a/docs/assets/images/sections/community/edit-file.PNG b/docs/assets/images/sections/community/edit-file.PNG
new file mode 100644
index 0000000000..c76e535d7f
Binary files /dev/null and b/docs/assets/images/sections/community/edit-file.PNG differ
diff --git a/docs/assets/images/sections/how_to_guides/basic/pipeline.png b/docs/assets/images/sections/how_to_guides/basic/pipeline.png
new file mode 100644
index 0000000000..0718e1c28c
Binary files /dev/null and b/docs/assets/images/sections/how_to_guides/basic/pipeline.png differ
diff --git a/docs/assets/images/sections/how_to_guides/tasks/task_print.png b/docs/assets/images/sections/how_to_guides/tasks/task_print.png
new file mode 100644
index 0000000000..95498c8c6c
Binary files /dev/null and b/docs/assets/images/sections/how_to_guides/tasks/task_print.png differ
diff --git a/docs/assets/pipelines/arena-hard.png b/docs/assets/pipelines/arena-hard.png
new file mode 100644
index 0000000000..a6a208c373
Binary files /dev/null and b/docs/assets/pipelines/arena-hard.png differ
diff --git a/docs/assets/pipelines/clair.png b/docs/assets/pipelines/clair.png
new file mode 100644
index 0000000000..c80e801f90
Binary files /dev/null and b/docs/assets/pipelines/clair.png differ
diff --git a/docs/assets/pipelines/clean-dataset.png b/docs/assets/pipelines/clean-dataset.png
new file mode 100644
index 0000000000..1f73e19ce6
Binary files /dev/null and b/docs/assets/pipelines/clean-dataset.png differ
diff --git a/docs/assets/pipelines/deepseek.png b/docs/assets/pipelines/deepseek.png
new file mode 100644
index 0000000000..7bcbb54df9
Binary files /dev/null and b/docs/assets/pipelines/deepseek.png differ
diff --git a/docs/assets/pipelines/deita.png b/docs/assets/pipelines/deita.png
new file mode 100644
index 0000000000..b552cf0c4b
Binary files /dev/null and b/docs/assets/pipelines/deita.png differ
diff --git a/docs/assets/pipelines/generate-preference-dataset.png b/docs/assets/pipelines/generate-preference-dataset.png
new file mode 100644
index 0000000000..fce35c8195
Binary files /dev/null and b/docs/assets/pipelines/generate-preference-dataset.png differ
diff --git a/docs/assets/pipelines/instruction_backtranslation.png b/docs/assets/pipelines/instruction_backtranslation.png
new file mode 100644
index 0000000000..152a0063d9
Binary files /dev/null and b/docs/assets/pipelines/instruction_backtranslation.png differ
diff --git a/docs/assets/pipelines/knowledge_graphs.png b/docs/assets/pipelines/knowledge_graphs.png
new file mode 100644
index 0000000000..f142e76c8f
Binary files /dev/null and b/docs/assets/pipelines/knowledge_graphs.png differ
diff --git a/docs/assets/pipelines/prometheus.png b/docs/assets/pipelines/prometheus.png
new file mode 100644
index 0000000000..10335f15c3
Binary files /dev/null and b/docs/assets/pipelines/prometheus.png differ
diff --git a/docs/assets/pipelines/sentence-transformer.png b/docs/assets/pipelines/sentence-transformer.png
new file mode 100644
index 0000000000..690dc9379d
Binary files /dev/null and b/docs/assets/pipelines/sentence-transformer.png differ
diff --git a/docs/assets/pipelines/ultrafeedback.png b/docs/assets/pipelines/ultrafeedback.png
new file mode 100644
index 0000000000..852edc4c36
Binary files /dev/null and b/docs/assets/pipelines/ultrafeedback.png differ
diff --git a/docs/assets/tutorials-assets/overview-apigen.jpg b/docs/assets/tutorials-assets/overview-apigen.jpg
new file mode 100644
index 0000000000..61deefac9a
Binary files /dev/null and b/docs/assets/tutorials-assets/overview-apigen.jpg differ
diff --git a/docs/index.md b/docs/index.md
index 37cf6f9fdf..ce76c96956 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -38,21 +38,39 @@ hide:
Distilabel is the framework for synthetic data and AI feedback for engineers who need fast, reliable and scalable pipelines based on verified research papers.
-If you just want to get started, we recommend you check the [documentation](http://distilabel.argilla.io/). Curious, and want to know more? Keep reading!
+
+
+- __Get started in 5 minutes!__
+
+ ---
+
+ Install distilabel with `pip` and run your first `Pipeline` to generate and evaluate synthetic data.
+
+ [:octicons-arrow-right-24: Quickstart](./sections/getting_started/quickstart.md)
+
+- __How-to guides__
+
+ ---
+
+ Get familiar with the basics of distilabel. Learn how to define `steps`, `tasks` and `llms` and run your `Pipeline`.
+
+ [:octicons-arrow-right-24: Learn more](./sections/how_to_guides/index.md)
+
+
## Why use distilabel?
Distilabel can be used for generating synthetic data and AI feedback for a wide variety of projects including traditional predictive NLP (classification, extraction, etc.), or generative and large language model scenarios (instruction following, dialogue generation, judging etc.). Distilabel's programmatic approach allows you to build scalable pipelines for data generation and AI feedback. The goal of distilabel is to accelerate your AI development by quickly generating high-quality, diverse datasets based on verified research methodologies for generating and judging with AI feedback.
-### Improve your AI output quality through data quality
+Improve your AI output quality through data quality
Compute is expensive and output quality is important. We help you **focus on data quality**, which tackles the root cause of both of these problems at once. Distilabel helps you to synthesize and judge data to let you spend your valuable time **achieving and keeping high-quality standards for your synthetic data**.
-### Take control of your data and models
+Take control of your data and models
**Ownership of data for fine-tuning your own LLMs** is not easy but distilabel can help you to get started. We integrate **AI feedback from any LLM provider out there** using one unified API.
-### Improve efficiency by quickly iterating on the right research and LLMs
+Improve efficiency by quickly iterating on the right data and models
Synthesize and judge data with **latest research papers** while ensuring **flexibility, scalability and fault tolerance**. So you can focus on improving your data and training your models.
diff --git a/docs/sections/community/contributor.md b/docs/sections/community/contributor.md
new file mode 100644
index 0000000000..180c46929e
--- /dev/null
+++ b/docs/sections/community/contributor.md
@@ -0,0 +1,161 @@
+---
+description: This is a step-by-step guide to help you contribute to the distilabel project. We are excited to have you on board! 🚀
+hide:
+ - footer
+---
+
+Thank you for investing your time in contributing to the project! Any contribution you make will be reflected in the most recent version of distilabel 🤩.
+
+??? Question "New to contributing in general?"
+ If you're a new contributor, read the [README](https://github.com/argilla-io/distilabel/blob/develop/README.md) to get an overview of the project. In addition, here are some resources to help you get started with open-source contributions:
+
+ * **Discord**: You are welcome to join the [distilabel Discord community](http://hf.co/join/discord), where you can keep in touch with other users, contributors and the distilabel team. In the following [section](#first-contact-in-discord), you can find more information on how to get started in Discord.
+ * **Git**: This is a very useful tool to keep track of the changes in your files. Using the command-line interface (CLI), you can make your contributions easily. For that, you need to have it [installed and updated](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git) on your computer.
+ * **GitHub**: It is a platform and cloud-based service that uses git and allows developers to collaborate on projects. To contribute to distilabel, you'll need to create an account. Check the [Contributor Workflow with Git and Github](#contributor-workflow-with-git-and-github) for more info.
+ * **Developer Documentation**: To collaborate, you'll need to set up an efficient environment. Check the [Installation](../getting_started/installation.md) guide to know how to do it.
+
+## First Contact in Discord
+
+Discord is a handy tool for more casual conversations and to answer day-to-day questions. As part of Hugging Face, we have set up some distilabel channels on the server. Click [here](http://hf.co/join/discord) to join the Hugging Face Discord community effortlessly.
+
+When part of the Hugging Face Discord, you can select "Channels & roles" and select "Argilla" along with any of the other groups that are interesting to you. "Argilla" will cover anything about argilla and distilabel. You can join the following channels:
+
+* **#argilla-distilabel-announcements**: 📣 Stay up-to-date.
+* **#argilla-distilabel-general**: 💬 For general discussions.
+* **#argilla-distilabel-help**: 🙋♀️ Need assistance? We're always here to help. Select the appropriate label (argilla or distilabel) for your issue and post it.
+
+So now there is only one thing left to do: introduce yourself and talk to the community. You'll always be welcome! 🤗👋
+
+
+## Contributor Workflow with Git and GitHub
+
+If you're working with distilabel and suddenly a new idea comes to your mind or you find an issue that can be improved, it's time to actively participate and contribute to the project!
+
+### Report an issue
+
+If you spot a problem, [search if an issue already exists](https://github.com/argilla-io/distilabel/issues?q=is%3Aissue), you can use the `Label` filter. If that is the case, participate in the conversation. If it does not exist, create an issue by clicking on `New Issue`. This will show various templates; choose the one that best suits your issue. Once you choose one, you will need to fill it in following the guidelines. Try to be as clear as possible. In addition, you can assign yourself to the issue and add or choose the right labels. Finally, click on `Submit new issue`.
+
+
+### Work with a fork
+
+#### Fork the distilabel repository
+
+After having reported the issue, you can start working on it. For that, you will need to create a fork of the project. To do that, click on the `Fork` button. Now, fill in the information. Remember to uncheck the `Copy develop branch only` if you are going to work in or from another branch (for instance, to fix documentation, the `main` branch is used). Then, click on `Create fork`.
+
+You will be redirected to your fork. You can see that you are in your fork because the name of the repository will be your `username/distilabel`, and it will indicate `forked from argilla-io/distilabel`.
+
+
+#### Clone your forked repository
+
+In order to make the required adjustments, clone the forked repository to your local machine. Choose the destination folder and run the following command:
+
+```sh
+git clone https://github.com/[your-github-username]/distilabel.git
+cd distilabel
+```
+
+To keep your fork’s main/develop branch up to date with our repo, add it as an upstream remote branch.
+
+```sh
+git remote add upstream https://github.com/argilla-io/distilabel.git
+```
+
+
+### Create a new branch
+
+For each issue you're addressing, it's advisable to create a new branch. GitHub offers a straightforward method to streamline this process.
+
+> ⚠️ Never work directly on the `main` or `develop` branch. Always create a new branch for your changes.
+
+Navigate to your issue, and on the right column, select `Create a branch`.
+
+![Create a branch](../../assets/images/sections/community/create-branch.PNG)
+
+After the new window pops up, the branch will be named after the issue and include a prefix such as feature/, bug/, or docs/ to facilitate quick recognition of the issue type. In the `Repository destination`, pick your fork ( [your-github-username]/distilabel), and then select `Change branch source` to specify the source branch for creating the new one. Complete the process by clicking `Create branch`.
+
+> 🤔 Remember that the `main` branch is only used to work with the documentation. For any other changes, use the `develop` branch.
+
+Now, locally, change to the new branch you just created.
+
+```sh
+git fetch origin
+git checkout [branch-name]
+```
+
+### Make changes and push them
+
+Make the changes you want in your local repository, and test that everything works and you are following the guidelines.
+
+Once you have finished, you can check the status of your repository and synchronize with the upstreaming repo with the following command:
+
+```sh
+# Check the status of your repository
+git status
+
+# Synchronize with the upstreaming repo
+git checkout [branch-name]
+git rebase [default-branch]
+```
+
+If everything is right, we need to commit and push the changes to your fork. For that, run the following commands:
+
+```sh
+# Add the changes to the staging area
+git add filename
+
+# Commit the changes by writing a proper message
+git commit -m "commit-message"
+
+# Push the changes to your fork
+git push origin [branch-name]
+```
+
+When pushing, you will be asked to enter your GitHub login credentials. Once the push is complete, all local commits will be on your GitHub repository.
+
+
+### Create a pull request
+
+Come back to GitHub, navigate to the original repository where you created your fork, and click on `Compare & pull request`.
+
+![compare-and-pr](../../assets/images/sections/community/compare-pull-request.PNG)
+
+First, click on `compare across forks` and select the right repositories and branches.
+
+> In the base repository, keep in mind that you should select either `main` or `develop` based on the modifications made. In the head repository, indicate your forked repository and the branch corresponding to the issue.
+
+Then, fill in the pull request template. You should add a prefix to the PR name, as we did with the branch above. If you are working on a new feature, you can name your PR as `feat: TITLE`. If your PR consists of a solution for a bug, you can name your PR as `bug: TITLE`. And, if your work is for improving the documentation, you can name your PR as `docs: TITLE`.
+
+In addition, on the right side, you can select a reviewer (for instance, if you discussed the issue with a member of the team) and assign the pull request to yourself. It is highly advisable to add labels to PR as well. You can do this again by the labels section right on the screen. For instance, if you are addressing a bug, add the `bug` label, or if the PR is related to the documentation, add the `documentation` label. This way, PRs can be easily filtered.
+
+Finally, fill in the template carefully and follow the guidelines. Remember to link the original issue and enable the checkbox to allow maintainer edits so the branch can be updated for a merge. Then, click on `Create pull request`.
+
+For the PR body, ensure you give a description of what the PR contains, and add examples if possible (and if they apply to the contribution) to help with the review process. You can take a look at [#PR 974](https://github.com/argilla-io/distilabel/pull/974) or [#PR 983](https://github.com/argilla-io/distilabel/pull/983) for examples of typical PRs.
+
+
+### Review your pull request
+
+Once you submit your PR, a team member will review your proposal. We may ask questions, request additional information, or ask for changes to be made before a PR can be merged, either using [suggested changes](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/incorporating-feedback-in-your-pull-request) or pull request comments.
+
+You can apply the changes directly through the UI (check the files changed and click on the right-corner three dots; see image below) or from your fork, and then commit them to your branch. The PR will be updated automatically, and the suggestions will appear as `outdated`.
+
+![edit-file-from-UI](../../assets/images/sections/community/edit-file.PNG)
+
+> If you run into any merge issues, check out this [git tutorial](https://github.com/skills/resolve-merge-conflicts) to help you resolve merge conflicts and other issues.
+
+
+### Your PR is merged!
+
+Congratulations 🎉🎊 We thank you 🤩
+
+Once your PR is merged, your contributions will be publicly visible on the [distilabel GitHub](https://github.com/argilla-io/distilabel#contributors).
+
+Additionally, we will include your changes in the next release based on our [development branch](https://github.com/argilla-io/argilla/tree/develop).
+
+## Additional resources
+
+Here are some helpful resources for your reference.
+
+* [Configuring Discord](https://support.discord.com/hc/en-us/categories/115000217151), a guide to learning how to get started with Discord.
+* [Pro Git](https://git-scm.com/book/en/v2), a book to learn Git.
+* [Git in VSCode](https://code.visualstudio.com/docs/sourcecontrol/overview), a guide to learning how to easily use Git in VSCode.
+* [GitHub Skills](https://skills.github.com/), an interactive course for learning GitHub.
\ No newline at end of file
diff --git a/docs/sections/community/developer_documentation.md b/docs/sections/community/developer_documentation.md
new file mode 100644
index 0000000000..ccc0795f87
--- /dev/null
+++ b/docs/sections/community/developer_documentation.md
@@ -0,0 +1,104 @@
+---
+description: This is a step-by-step guide to help you develop distilabel.
+hide:
+ - footer
+---
+
+Thank you for investing your time in contributing to the project!
+
+If you don't have the repository locally, and need any help, go to the [contributor guide](../community/contributor.md) and read the contributor workflow with Git and GitHub first.
+
+## Set up the Python environment
+
+To work on the `distilabel`, you must install the package on your system.
+
+!!! Tip
+ This guide will use `uv`, but `pip` and `venv` can be used as well, this guide can work quite similar with both options.
+
+From the root of the cloned Distilabel repository, you should move to the distilabel folder in your terminal.
+
+```bash
+cd distilabel
+```
+
+### Create a virtual environment
+
+The first step will be creating a virtual environment to keep our dependencies isolated. Here we are choosing `python 3.11` ([uv venv](https://docs.astral.sh/uv/pip/environments/) documentation), and then activate it:
+
+```bash
+uv venv .venv --python 3.11
+source .venv/bin/activate
+```
+
+### Install the project
+
+Installing from local (we are using [`uv pip`](https://docs.astral.sh/uv/pip/packages/)):
+
+```bash
+uv pip install -e .
+```
+
+We have extra dependencies with their name, depending on the part you are working on, you may want to install some dependency (take a look at `pyproject.toml` in the repo to see all the extra dependencies):
+
+```bash
+uv pip install -e ".[vllm,outlines]"
+```
+
+### Linting and formatting
+
+To maintain a consistent code format, install the pre-commit hooks to run before each commit automatically (we rely heavily on [`ruff`](https://docs.astral.sh/ruff/)):
+
+```bash
+uv pip install -e ".[dev]"
+pre-commit install
+```
+
+### Running tests
+
+All the changes you add to the codebase should come with tests, either `unit` or `integration` tests, depending on the type of change, which are placed under `tests/unit` and `tests/integration` respectively.
+
+Start by installing the tests dependencies:
+
+```bash
+uv pip install ".[tests]"
+```
+
+Running the whole tests suite may take some time, and you will need all the dependencies installed, so just run your tests, and the whole tests suite will be run for you in the CI:
+
+```bash
+# Run specific tests
+pytest tests/unit/steps/generators/test_data.py
+```
+
+## Set up the documentation
+
+To contribute to the documentation and generate it locally, ensure you have installed the development dependencies:
+
+```bash
+uv pip install -e ".[docs]"
+```
+
+And run the following command to create the development server with `mkdocs`:
+
+```bash
+mkdocs serve
+```
+
+### Documentation guidelines
+
+As mentioned, we use mkdocs to build the documentation. You can write the documentation in `markdown` format, and it will automatically be converted to HTML. In addition, you can include elements such as tables, tabs, images, and others, as shown in this guide. We recommend following these guidelines:
+
+- Use clear and concise language: Ensure the documentation is easy to understand for all users by using straightforward language and including meaningful examples. Images are not easy to maintain, so use them only when necessary and place them in the appropriate folder within the docs/assets/images directory.
+
+- Verify code snippets: Double-check that all code snippets are correct and runnable.
+
+- Review spelling and grammar: Check the spelling and grammar of the documentation.
+
+- Update the table of contents: If you add a new page, include it in the relevant index.md or the mkdocs.yml file.
+
+### Components gallery
+
+The components gallery section of the documentation is automatically generated thanks to a custom plugin, it will be run when `mkdocs serve` is called. This guide to the steps helps us visualize each step, as well as examples of use.
+
+!!! Note
+ Changes done to the docstrings of `Steps/Tasks/LLMs` won't appear in the components gallery automatically, you will have to stop the `mkdocs` server and run it again to see the changes, everything else is reloaded automatically.
diff --git a/docs/sections/getting_started/faq.md b/docs/sections/getting_started/faq.md
index 27768a3c6f..7a78126c46 100644
--- a/docs/sections/getting_started/faq.md
+++ b/docs/sections/getting_started/faq.md
@@ -7,20 +7,20 @@ hide:
# Frequent Asked Questions (FAQ)
??? faq "How can I rename the columns in a batch?"
- Every [`Step`][distilabel.steps.base.Step] has both `input_mappings` and `output_mappings` attributes, that can be used to rename the columns in each batch.
+ Every [`Step`][distilabel.steps.base.Step] has both `input_mappings` and `output_mappings` attributes that can be used to rename the columns in each batch.
- But `input_mappings` will only map, meaning that if you have a batch with the column `A` and you want to rename to `B`, you should use `input_mappings={"A": "B"}`, but that will only be applied to that specific [`Step`][distilabel.steps.base.Step] meaning that the next step in the pipeline will still have the column `A` instead of `B`.
+ But `input_mappings` will only map, meaning that if you have a batch with the column `A` and you want to rename it to `B`, you should use `input_mappings={"A": "B"}`, but that will only be applied to that specific [`Step`][distilabel.steps.base.Step] meaning that the next step in the pipeline will still have the column `A` instead of `B`.
While `output_mappings` will indeed apply the rename, meaning that if the [`Step`][distilabel.steps.base.Step] produces the column `A` and you want to rename to `B`, you should use `output_mappings={"A": "B"}`, and that will be applied to the next [`Step`][distilabel.steps.base.Step] in the pipeline.
??? faq "Will the API Keys be exposed when sharing the pipeline?"
No, those will be masked out using `pydantic.SecretStr`, meaning that those won't be exposed when sharing the pipeline.
- This also means that if you want to re-run your own pipeline and the API keys have not been provided via environment variable but either via attribute or runtime parameter, you will need to provide them again.
+ This also means that if you want to re-run your own pipeline and the API keys have not been provided via environment variable but either via an attribute or runtime parameter, you will need to provide them again.
??? faq "Does it work for Windows?"
- Yes, but you may need to set the `multiprocessing` context in advance, to ensure that the `spawn` method is used, since the default method `fork` is not available on Windows.
+ Yes, but you may need to set the `multiprocessing` context in advance to ensure that the `spawn` method is used since the default method `fork` is not available on Windows.
```python
import multiprocessing as mp
@@ -29,16 +29,34 @@ hide:
```
??? faq "Will the custom Steps / Tasks / LLMs be serialized too?"
- No, at the moment only the references to the classes within the `distilabel` library will be serialized, meaning that if you define a custom class used within the pipeline, the serialization won't break, but the deserialize will fail since the class won't be available, unless used from the same file.
+ No, at the moment, only the references to the classes within the `distilabel` library will be serialized, meaning that if you define a custom class used within the pipeline, the serialization won't break, but the deserialize will fail since the class won't be available unless used from the same file.
??? faq "What happens if `Pipeline.run` fails? Do I lose all the data?"
- No, indeed we're using a cache mechanism to store all the intermediate results in disk, so that if a [`Step`][distilabel.steps.base.Step] fails, the pipeline can be re-run from that point without losing the data, only if nothing is changed in the `Pipeline`.
+ No, indeed, we're using a cache mechanism to store all the intermediate results in the disk so, if a [`Step`][distilabel.steps.base.Step] fails; the pipeline can be re-run from that point without losing the data, only if nothing is changed in the `Pipeline`.
All the data will be stored in `.cache/distilabel`, but the only data that will persist at the end of the `Pipeline.run` execution is the one from the leaf step/s, so bear that in mind.
For more information on the caching mechanism in `distilabel`, you can check the [Learn - Advanced - Caching](../how_to_guides/advanced/caching.md) section.
- Also note that when running a [`Step`][distilabel.steps.base.Step] or a [`Task`][distilabel.steps.tasks.Task] standalone, the cache mechanism won't be used, so if you want to use that, you should use the `Pipeline` context manager.
+ Also, note that when running a [`Step`][distilabel.steps.base.Step] or a [`Task`][distilabel.steps.tasks.Task] standalone, the cache mechanism won't be used, so if you want to use that, you should use the `Pipeline` context manager.
??? faq "How can I use the same `LLM` across several tasks without having to load it several times?"
You can serve the LLM using a solution like TGI or vLLM, and then connect to it using an `AsyncLLM` client like `InferenceEndpointsLLM` or `OpenAILLM`. Please refer to [Serving LLMs guide](../how_to_guides/advanced/serving_an_llm_for_reuse.md) for more information.
+
+??? faq "Can `distilabel` be used with [OpenAI Batch API](https://platform.openai.com/docs/guides/batch)?"
+ Yes, `distilabel` is integrated with OpenAI Batch API via [OpenAILLM][distilabel.llms.openai.OpenAILLM]. Check [LLMs - Offline Batch Generation](../how_to_guides/basic/llm/index.md#offline-batch-generation) for a small example on how to use it and [Advanced - Offline Batch Generation](../how_to_guides/advanced/offline_batch_generation.md) for a more detailed guide.
+
+??? faq "Prevent overloads on [Free Serverless Endpoints][distilabel.llms.huggingface.InferenceEndpointsLLM]"
+ When running a task using the [InferenceEndpointsLLM][distilabel.llms.huggingface.InferenceEndpointsLLM] with Free Serverless Endpoints, you may be facing some errors such as `Model is overloaded` if you let the batch size to the default (set at 50). To fix the issue, lower the value or even better set `input_batch_size=1` in your task. It may take a longer time to finish, but please remember this is a free service.
+
+ ```python
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.steps import TextGeneration
+
+ TextGeneration(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+ input_batch_size=1
+ )
+ ```
diff --git a/docs/sections/getting_started/installation.md b/docs/sections/getting_started/installation.md
index 804aa8de7e..54e130b7fa 100644
--- a/docs/sections/getting_started/installation.md
+++ b/docs/sections/getting_started/installation.md
@@ -6,9 +6,6 @@ hide:
# Installation
-!!! NOTE
- Since `distilabel` v1.0.0 was recently released, we refactored most of the stuff, so the installation below only applies to `distilabel` v1.0.0 and above.
-
You will need to have at least Python 3.9 or higher, up to Python 3.12, since support for the latter is still a work in progress.
To install the latest release of the package from PyPI you can use the following command:
@@ -30,6 +27,8 @@ pip install "distilabel @ git+https://github.com/argilla-io/distilabel.git@devel
Additionally, as part of `distilabel` some extra dependencies are available, mainly to add support for some of the LLM integrations we support. Here's a list of the available extras:
+### LLMs
+
- `anthropic`: for using models available in [Anthropic API](https://www.anthropic.com/api) via the `AnthropicLLM` integration.
- `argilla`: for exporting the generated datasets to [Argilla](https://argilla.io/).
@@ -56,11 +55,29 @@ Additionally, as part of `distilabel` some extra dependencies are available, mai
- `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration.
+- `sentence-transformers`: for generating sentence embeddings using [sentence-transformers](https://github.com/UKPLab/sentence-transformers).
+
+### Data processing
+
+- `ray`: for scaling and distributing a pipeline with [Ray](https://github.com/ray-project/ray).
+
+- `faiss-cpu` and `faiss-gpu`: for generating sentence embeddings using [faiss](https://github.com/facebookresearch/faiss).
+
+- `minhash`: for using minhash for duplicate detection with [datasketch](https://github.com/datasketch/datasketch) and [nltk](https://github.com/nltk/nltk).
+
+- `text-clustering`: for using text clustering with [UMAP](https://github.com/lmcinnes/umap) and [Scikit-learn](https://github.com/scikit-learn/scikit-learn).
+
+### Structured generation
+
+- `outlines`: for using structured generation of LLMs with [outlines](https://github.com/outlines-dev/outlines).
+
+- `instructor`: for using structured generation of LLMs with [Instructor](https://github.com/jxnl/instructor/).
+
## Recommendations / Notes
The [`mistralai`](https://github.com/mistralai/client-python) dependency requires Python 3.9 or higher, so if you're willing to use the `distilabel.llms.MistralLLM` implementation, you will need to have Python 3.9 or higher.
-In some cases like [`transformers`](https://github.com/huggingface/transformers) and [`vllm`](https://github.com/vllm-project/vllm) the installation of [`flash-attn`](https://github.com/Dao-AILab/flash-attention) is recommended if you are using a GPU accelerator, since it will speed up the inference process, but the installation needs to be done separately, as it's not included in the `distilabel` dependencies.
+In some cases like [`transformers`](https://github.com/huggingface/transformers) and [`vllm`](https://github.com/vllm-project/vllm), the installation of [`flash-attn`](https://github.com/Dao-AILab/flash-attention) is recommended if you are using a GPU accelerator since it will speed up the inference process, but the installation needs to be done separately, as it's not included in the `distilabel` dependencies.
```sh
pip install flash-attn --no-build-isolation
diff --git a/docs/sections/getting_started/quickstart.md b/docs/sections/getting_started/quickstart.md
index 4fd7de607b..7af9bca8f0 100644
--- a/docs/sections/getting_started/quickstart.md
+++ b/docs/sections/getting_started/quickstart.md
@@ -4,14 +4,38 @@ hide:
- toc
---
+
+
+
+
# Quickstart
-To start off, `distilabel` is a framework for building pipelines for generating synthetic data using LLMs, that defines a [`Pipeline`][distilabel.pipeline.Pipeline] which orchestrates the execution of the [`Step`][distilabel.steps.base.Step] subclasses, and those will be connected as nodes in a Direct Acyclic Graph (DAG).
+Distilabel provides all the tools you need to your scalable and reliable pipelines for synthetic data generation and AI-feedback. Pipelines are used to generate data, evaluate models, manipulate data, or any other general task. They are made up of different components: Steps, Tasks and LLMs, which are chained together in a directed acyclic graph (DAG).
+
+- **Steps**: These are the building blocks of your pipeline. Normal steps are used for basic executions like loading data, applying some transformations, or any other general task.
+- **Tasks**: These are steps that rely on LLMs and prompts to perform generative tasks. For example, they can be used to generate data, evaluate models or manipulate data.
+- **LLMs**: These are the models that will perform the task. They can be local or remote models, and open-source or commercial models.
+
+Pipelines are designed to be scalable and reliable. They can be executed in a distributed manner, and they can be cached and recovered. This is useful when dealing with large datasets or when you want to ensure that your pipeline is reproducible.
+
+Besides that, pipelines are designed to be modular and flexible. You can easily add new steps, tasks, or LLMs to your pipeline, and you can also easily modify or remove them. An example architecture of a pipeline to generate a dataset of preferences is the following:
-That being said, in this guide we will walk you through the process of creating a simple pipeline that uses the [`OpenAILLM`][distilabel.llms.OpenAILLM] class to generate text. The [`Pipeline`][distilabel.pipeline.Pipeline] will load a dataset that contains a column named `prompt` from the Hugging Face Hub via the step [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] and then use the [`OpenAILLM`][distilabel.llms.OpenAILLM] class to generate text based on the dataset using the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task.
+## Installation
+
+To install the latest release with `hf-inference-endpoints` extra of the package from PyPI you can use the following command:
+
+```sh
+pip install distilabel[hf-inference-endpoints] --upgrade
+```
+
+## Define a pipeline
+
+In this guide we will walk you through the process of creating a simple pipeline that uses the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text. The [`Pipeline`][distilabel.pipeline.Pipeline] will load a dataset that contains a column named `prompt` from the Hugging Face Hub via the step [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] and then use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text based on the dataset using the [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task.
+
+> You can check the available models in the [Hugging Face Model Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) and filter by `Inference status`.
```python
-from distilabel.llms import OpenAILLM
+from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
@@ -21,19 +45,22 @@ with Pipeline( # (1)
description="A simple text generation pipeline",
) as pipeline: # (2)
load_dataset = LoadDataFromHub( # (3)
- name="load_dataset",
output_mappings={"prompt": "instruction"},
)
text_generation = TextGeneration( # (4)
- name="text_generation",
- llm=OpenAILLM(model="gpt-3.5-turbo"), # (5)
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+ ), # (5)
+ system_prompt="You are a creative AI Assistant writer.",
+ template="Follow the following instruction: {{ instruction }}" # (6)
)
- load_dataset >> text_generation # (6)
+ load_dataset >> text_generation # (7)
if __name__ == "__main__":
- distiset = pipeline.run( # (7)
+ distiset = pipeline.run( # (8)
parameters={
load_dataset.name: {
"repo_id": "distilabel-internal-testing/instruction-dataset-mini",
@@ -49,49 +76,23 @@ if __name__ == "__main__":
},
},
)
- distiset.push_to_hub(repo_id="distilabel-example") # (8)
+ distiset.push_to_hub(repo_id="distilabel-example") # (9)
```
1. We define a [`Pipeline`][distilabel.pipeline.Pipeline] with the name `simple-text-generation-pipeline` and a description `A simple text generation pipeline`. Note that the `name` is mandatory and will be used to calculate the `cache` signature path, so changing the name will change the cache path and will be identified as a different pipeline.
2. We are using the [`Pipeline`][distilabel.pipeline.Pipeline] context manager, meaning that every [`Step`][distilabel.steps.base.Step] subclass that is defined within the context manager will be added to the pipeline automatically.
-3. We define a [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] step named `load_dataset` that will load a dataset from the Hugging Face Hub, as provided via runtime parameters in the `pipeline.run` method below, but it can also be defined within the class instance via the arg `repo_id=...`. This step will basically produce output batches with the rows from the dataset, and the column `prompt` will be mapped to the `instruction` field.
-
-4. We define a [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task named `text_generation` that will generate text based on the `instruction` field from the dataset. This task will use the [`OpenAILLM`][distilabel.llms.OpenAILLM] class with the model `gpt-3.5-turbo`.
-
-5. We define the [`OpenAILLM`][distilabel.llms.OpenAILLM] class with the model `gpt-3.5-turbo` that will be used by the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task. In this case, since the [`OpenAILLM`][distilabel.llms.OpenAILLM] is used, we assume that the `OPENAI_API_KEY` environment variable is set, and the OpenAI API will be used to generate the text.
+3. We define a [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] step named `load_dataset` that will load a dataset from the Hugging Face Hub, as provided via runtime parameters in the `pipeline.run` method below, but it can also be defined within the class instance via the arg `repo_id=...`. This step will produce output batches with the rows from the dataset, and the column `prompt` will be mapped to the `instruction` field.
-6. We connect the `load_dataset` step to the `text_generation` task using the `rshift` operator, meaning that the output from the `load_dataset` step will be used as input for the `text_generation` task.
+4. We define a [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task named `text_generation` that will generate text based on the `instruction` field from the dataset. This task will use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct`.
-7. We run the pipeline with the parameters for the `load_dataset` and `text_generation` steps. The `load_dataset` step will use the repository `distilabel-internal-testing/instruction-dataset-mini` and the `test` split, and the `text_generation` task will use the `generation_kwargs` with the `temperature` set to `0.7` and the `max_new_tokens` set to `512`.
+5. We define the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct` that will be used by the [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task. In this case, since the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] is used, we assume that the `HF_TOKEN` environment variable is set.
-8. Optionally, we can push the generated [`Distiset`][distilabel.distiset.Distiset] to the Hugging Face Hub repository `distilabel-example`. This will allow you to share the generated dataset with others and use it in other pipelines.
-
-## Minimal example
-
-`distilabel` gives a lot of flexibility to create your pipelines, but to start right away, you can omit a lot of the details and let default values:
-
-```python
-from distilabel.llms import InferenceEndpointsLLM
-from distilabel.pipeline import Pipeline
-from distilabel.steps.tasks import TextGeneration
-from datasets import load_dataset
-
-
-dataset = load_dataset("distilabel-internal-testing/instruction-dataset-mini", split="test")
-
-with Pipeline() as pipeline: # (1)
- TextGeneration(llm=InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct")) # (2)
-
-
-if __name__ == "__main__":
- distiset = pipeline.run(dataset=dataset) # (3)
- distiset.push_to_hub(repo_id="distilabel-example")
-```
+6. Both `system_prompt` and `template` are optional fields. The `template` must be informed as a string following the [Jinja2](https://jinja.palletsprojects.com/en/3.1.x/templates/#synopsis) template format, and the fields that appear there ("instruction" in this case, which corresponds to the default) must be informed in the `columns` attribute. The component gallery for [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) has examples to get you started.
-1. The [`Pipeline`][distilabel.pipeline.Pipeline] can take no arguments and generate a default name on it's own that will be tracked internally.
+7. We connect the `load_dataset` step to the `text_generation` task using the `rshift` operator, meaning that the output from the `load_dataset` step will be used as input for the `text_generation` task.
-2. Just as with the [`Pipeline`][distilabel.pipeline.Pipeline], the [`Step`][distilabel.steps.base.Step]s don't explicitly need a name.
+8. We run the pipeline with the parameters for the `load_dataset` and `text_generation` steps. The `load_dataset` step will use the repository `distilabel-internal-testing/instruction-dataset-mini` and the `test` split, and the `text_generation` task will use the `generation_kwargs` with the `temperature` set to `0.7` and the `max_new_tokens` set to `512`.
-3. You can generate the dataset as you would normally do with Hugging Face and pass the dataset to the run method.
+9. Optionally, we can push the generated [`Distiset`][distilabel.distiset.Distiset] to the Hugging Face Hub repository `distilabel-example`. This will allow you to share the generated dataset with others and use it in other pipelines.
diff --git a/docs/sections/how_to_guides/advanced/caching.md b/docs/sections/how_to_guides/advanced/caching.md
index 1fc9414940..d4a03c09fd 100644
--- a/docs/sections/how_to_guides/advanced/caching.md
+++ b/docs/sections/how_to_guides/advanced/caching.md
@@ -1,135 +1,60 @@
-# Cache and recover pipeline executions
+# Pipeline cache
-Distilabel `Pipelines` automatically save all the intermediate steps to avoid losing any data in case of error.
+`distilabel` will automatically save all the intermediate outputs generated by each [`Step`][distilabel.steps.base.Step] of a [`Pipeline`][distilabel.pipeline.local.Pipeline], so these outputs can be reused to recover the state of a pipeline execution that was stopped before finishing or to not have to re-execute steps from a pipeline after adding a new downstream step.
-## Cache directory
+## How to enable/disable the cache
-Out of the box, the `Pipeline` will use the `~/.cache/distilabel/pipelines` directory to store the different pipelines[^1]:
+The use of the cache can be toggled using the `use_cache` parameter of the [`Pipeline.use_cache`][distilabel.pipeline.base.BasePipeline.run] method. If `True`, then `distilabel ` will use the reuse the outputs of previous executions for the new execution. If `False`, then `distilabel` will re-execute all the steps of the pipeline to generate new outputs for all the steps.
```python
-from distilabel.pipeline.local import Pipeline
-
-with Pipeline(name="cache_testing") as pipeline:
+with Pipeline(name="my-pipeline") as pipeline:
...
-```
-This directory can be modified by setting the `DISTILABEL_CACHE_DIR` environment variable (`export DISTILABEL_CACHE_DIR=my_cache_dir`) or by explicitly passing the `cache_dir` variable to the `Pipeline` constructor like so:
-
-```python
-with Pipeline(name="cache_testing", cache_dir="~/my_cache_dir") as pipeline:
- ...
+if __name__ == "__main__":
+ distiset = pipeline.run(use_cache=False) # (1)
```
-[^1]:
-
- The pipelines will be organized according to the pipeline's name attribute, and then by the hash, in case you want to look for something manually, like the following example:
-
- ```bash
- $ tree ~/.cache/distilabel/pipelines/
- ├── cache_testing
- │ └── 13da04d2cc255b2180d6bebb50fb5be91124f70d
- │ ├── batch_manager.json
- │ ├── batch_manager_steps
- │ │ └── succeed_always_0.json
- │ ├── data
- │ │ └── succeed_always_0
- │ │ └── 00001.parquet
- │ ├── pipeline.log
- │ └── pipeline.yaml
- └── test-pipe
- └── f23b95d7ad4e9301a70b2a54c953f8375ebfcd5c
- ├── batch_manager.json
- ├── batch_manager_steps
- │ └── text_generation_0.json
- ├── data
- │ └── text_generation_0
- │ └── 00001.parquet
- ├── pipeline.log
- └── pipeline.yaml
- ```
-
-## How does it work?
-
-Let's take a look at the logging messages from a sample pipeline.
-
-When we run a `Pipeline` for the first time
-
-![Pipeline 1](../../../assets/images/sections/caching/caching_pipe_1.png)
-
-If we decide to stop the pipeline (say we kill the run altogether via `CTRL + C` or `CMD + C` in *macOS*), we will see the signal sent to the different workers:
-
-![Pipeline 2](../../../assets/images/sections/caching/caching_pipe_2.png)
-
-After this step, when we run again the pipeline, the first log message we see corresponds to "Load pipeline from cache", which will restart processing from where it stopped:
-
-![Pipeline 3](../../../assets/images/sections/caching/caching_pipe_3.png)
+1. Pipeline cache is disabled
-Finally, if we decide to run the same `Pipeline` after it has finished completely, it won't start again but resume the process, as we already have all the data processed:
+In addition, the cache can be enabled/disabled at [`Step`][distilabel.steps.base.Step] level using its `use_cache` attribute. If `True`, then the outputs of the step will be reused in the new pipeline execution. If `False`, then the step will be re-executed to generate new outputs. If the cache of one step is disabled and the outputs have to be regenerated, then the outputs of the steps that depend on this step will also be regenerated.
-![Pipeline 4](../../../assets/images/sections/caching/caching_pipe_4.png)
-
-### Serialization
-
-Let's see what gets serialized by looking at a sample `Pipeline`'s cached folder:
-
-```bash
-$ tree ~/.cache/distilabel/pipelines/73ca3f6b7a613fb9694db7631cc038d379f1f533
-├── batch_manager.json
-├── batch_manager_steps
-│ ├── generate_response.json
-│ └── rename_columns.json
-├── data
-│ └── generate_response
-│ ├── 00001.parquet
-│ └── 00002.parquet
-└── pipeline.yaml
+```python
+with Pipeline(name="writting-assistant") as pipeline:
+ load_data = LoadDataFromDicts(
+ data=[
+ {
+ "instruction": "How much is 2+2?"
+ }
+ ]
+ )
+
+ generation = TextGeneration(
+ llm=InferenceEndpointsLLM(
+ model_id="Qwen/Qwen2.5-72B-Instruct",
+ generation_kwargs={
+ "temperature": 0.8,
+ "max_new_tokens": 512,
+ },
+ ),
+ use_cache=False # (1)
+ )
+
+ load_data >> generation
+
+if __name__ == "__main__":
+ distiset = pipeline.run()
```
-The `Pipeline` will have a signature created from the arguments that define it so we can find it afterwards, and the contents are the following:
-
-- `batch_manager.json`
-
- Folder that stores the content of the internal batch manager to keep track of the data. Along with the `batch_manager_steps/` they store the information to restart the `Pipeline`. One shouldn't need to know about it.
-
-- `pipeline.yaml`
-
- This file contains a representation of the `Pipeline` in *YAML* format. If we push a `Distiset` to the Hugging Face Hub as obtained from calling `Pipeline.run`, this file will be stored at our datasets' repository, allowing to reproduce the `Pipeline` using the `CLI`:
+1. Step cache is disabled and every time the pipeline is executed, this step will be re-executed
- ```bash
- distilabel pipeline run --config "path/to/pipeline.yaml"
- ```
+## How a cache hit is triggered
-- `data/`
+`distilabel` groups information and data generated by a `Pipeline` using the name of the pipeline, so the first factor that triggers a cache hit is the name of the pipeline. The second factor, is the [`Pipeline.signature`][distilabel.pipeline.local.Pipeline.signature] property. This property returns a hash that is generated using the names of the steps used in the pipeline and their connections. The third factor, is the [`Pipeline.aggregated_steps_signature`][distilabel.pipeline.local.Pipeline.aggregated_steps_signature] property which is used to determine if the new pipeline execution is exactly the same as one of the previous i.e. the pipeline contains exactly the same steps, with exactly the same connections and the steps are using exactly the same parameters. If these three factors are met, then the cache hit is triggered and the pipeline won't get re-executed and instead the function [`create_distiset`][distilabel.distiset.create_distiset] will be used to create the resulting [`Distiset`][distilabel.distiset.Distiset] using the outputs of the previous execution, as it can be seen in the following image:
- Folder that stores the data generated, with a special folder to keep track of each `leaf_step` separately. We can recreate a `Distiset` from the contents of this folder (*Parquet* files), as we will see next.
-
-- `pipeline.log`
-
- This file stores the logs that the `Pipeline` generated while processing. Just as with the `pipeline.yaml` file, it will be pushed to the Hugging Face Hub datasets` repository to keep track of the information.
-
-## create_distiset
-
-In case we wanted to regenerate the dataset from the `cache`, we can do it using the [`create_distiset`][distilabel.distiset.create_distiset] function and passing the path to the `/data` folder inside our `Pipeline`:
-
-```python
-from pathlib import Path
-from distilabel.distiset import create_distiset
-
-path = Path("~/.cache/distilabel/pipelines/73ca3f6b7a613fb9694db7631cc038d379f1f533/data")
-ds = create_distiset(path)
-ds
-# Distiset({
-# generate_response: DatasetDict({
-# train: Dataset({
-# features: ['instruction', 'response'],
-# num_rows: 80
-# })
-# })
-# })
-```
+![Complete cache hit](../../../assets/images/sections/caching/caching_1.png)
-!!! Note
+If the new pipeline execution have a different `Pipeline.aggregated_steps_signature` i.e. at least one step has changed its parameters, `distilabel` will reuse the outputs of the steps that have not changed and re-execute the steps that have changed, as it can be seen in the following image:
- Internally, the function will try to inject the `pipeline_path` variable if it's not passed via argument, assuming it's in the parent directory of the current one, called `pipeline.yaml`. If the file doesn't exist, it won't raise any error, but take into account that if the `Distiset` is pushed to the Hugging Face Hub, the `pipeline.yaml` won't be generated. The same happens with the `pipeline.log` file, it can be passed via `log_filename_path`, but it will try to locate it automatically.
+![Partial cache hit](../../../assets/images/sections/caching/caching_2.png)
- Lastly, there is the option of including the `distilabel_metadata` column in the final dataset. This column can contain custom metadata generated automatically by the pipeline, like the raw output from an `LLM` without formatting in case of failure, and we can decide whether to include it using the `enable_metadata` argument.
+The same pipeline from above gets executed a third time, but this time the last step `text_generation_1` changed, so it's needed to re-execute it. The other steps, as they have not been changed, doesn't need to be re-executed and their outputs are reused.
diff --git a/docs/sections/how_to_guides/advanced/offline_batch_generation.md b/docs/sections/how_to_guides/advanced/offline_batch_generation.md
new file mode 100644
index 0000000000..b45ad1d716
--- /dev/null
+++ b/docs/sections/how_to_guides/advanced/offline_batch_generation.md
@@ -0,0 +1,47 @@
+The [offline batch generation](../basic/llm/index.md#offline-batch-generation) is a feature that some `LLM`s implemented in `distilabel` offers, allowing to send the inputs to a LLM-as-a-service platform and waiting for the outputs in a asynchronous manner. LLM-as-a-service platforms offer this feature as it allows them to gather many inputs and creating batches as big as the hardware allows, maximizing the hardware utilization and reducing the cost of the service. In exchange, the user has to wait certain time for the outputs to be ready but the cost per token is usually much lower.
+
+`distilabel` pipelines are able to handle `LLM`s that offer this feature in the following way:
+
+* The first time the pipeline gets executed, the `LLM` will send the inputs to the platform. The platform will return jobs ids that can be used later to check the status of the jobs and retrieve the results. The `LLM` will save these jobs ids in its `jobs_ids` attribute and raise an special exception [DistilabelOfflineBatchGenerationNotFinishedException][distilabel.exceptions.DistilabelOfflineBatchGenerationNotFinishedException] that will be handled by the `Pipeline`. The jobs ids will be saved in the pipeline cache, so they can be used in subsequent calls.
+* The second time and subsequent calls will recover the pipeline execution and the `LLM` won't send the inputs again to the platform. This time as it has the `jobs_ids` it will check if the jobs have finished, and if they have then it will retrieve the results and return the outputs. If they haven't finished, then it will raise again `DistilabelOfflineBatchGenerationNotFinishedException` again.
+* In addition, LLMs with offline batch generation can be specified to do polling until the jobs have finished, blocking the pipeline until they are done. If for some reason the polling needs to be stopped, one can press ++ctrl+c++ or ++cmd+c++ depending on your OS (or send a `SIGINT` to the main process) which will stop the polling and raise `DistilabelOfflineBatchGenerationNotFinishedException` that will be handled by the pipeline as described above.
+
+!!! WARNING
+
+ In order to recover the pipeline execution and retrieve the results, the pipeline cache must be enabled. If the pipeline cache is disabled, then it will send the inputs again and create different jobs incurring in extra costs.
+
+
+## Example pipeline using `OpenAILLM` with offline batch generation
+
+```python
+from distilabel.llms import OpenAILLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import LoadDataFromHub
+from distilabel.steps.tasks import TextGeneration
+
+with Pipeline() as pipeline:
+ load_data = LoadDataFromHub(output_mappings={"prompt": "instruction"})
+
+ text_generation = TextGeneration(
+ llm=OpenAILLM(
+ model="gpt-3.5-turbo",
+ use_offline_batch_generation=True, # (1)
+ )
+ )
+
+ load_data >> text_generation
+
+
+if __name__ == "__main__":
+ distiset = pipeline.run(
+ parameters={
+ load_data.name: {
+ "repo_id": "distilabel-internal-testing/instruction-dataset",
+ "split": "test",
+ "batch_size": 500,
+ },
+ }
+ )
+```
+
+1. Indicate that the `OpenAILLM` should use offline batch generation.
diff --git a/docs/sections/how_to_guides/advanced/saving_step_generated_artifacts.md b/docs/sections/how_to_guides/advanced/saving_step_generated_artifacts.md
new file mode 100644
index 0000000000..3d2e566047
--- /dev/null
+++ b/docs/sections/how_to_guides/advanced/saving_step_generated_artifacts.md
@@ -0,0 +1,123 @@
+# Saving step generated artifacts
+
+Some `Step`s might need to produce an auxiliary artifact that is not a result of the computation, but is needed for the computation. For example, the [`FaissNearestNeighbour`](../../../components-gallery/steps/faissnearestneighbour.md) needs to create a Faiss index to compute the output of the step which are the top `k` nearest neighbours for each input. Generating the Faiss index takes time and it could potentially be reused outside of the `distilabel` pipeline, so it would be a shame not saving it.
+
+For this reason, `Step`s have a method called `save_artifact` that allows saving artifacts that will be included along the outputs of the pipeline in the generated [`Distiset`][distilabel.distiset.Distiset]. The generated artifacts will be uploaded and saved when using `Distiset.push_to_hub` or `Distiset.save_to_disk` respectively. Let's see how to use it with a simple example.
+
+```python
+from typing import List, TYPE_CHECKING
+from distilabel.steps import GlobalStep, StepInput, StepOutput
+import matplotlib.pyplot as plt
+
+if TYPE_CHECKING:
+ from distilabel.steps import StepOutput
+
+
+class CountTextCharacters(GlobalStep):
+ @property
+ def inputs(self) -> List[str]:
+ return ["text"]
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["text_character_count"]
+
+ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
+ character_counts = []
+
+ for input in inputs:
+ text_character_count = len(input["text"])
+ input["text_character_count"] = text_character_count
+ character_counts.append(text_character_count)
+
+ # Generate plot with the distribution of text character counts
+ plt.figure(figsize=(10, 6))
+ plt.hist(character_counts, bins=30, edgecolor="black")
+ plt.title("Distribution of Text Character Counts")
+ plt.xlabel("Character Count")
+ plt.ylabel("Frequency")
+
+ # Save the plot as an artifact of the step
+ self.save_artifact(
+ name="text_character_count_distribution",
+ write_function=lambda path: plt.savefig(path / "figure.png"),
+ metadata={"type": "image", "library": "matplotlib"},
+ )
+
+ plt.close()
+
+ yield inputs
+```
+
+As it can be seen in the example above, we have created a simple step that counts the number of characters in each input text and generates a histogram with the distribution of the character counts. We save the histogram as an artifact of the step using the `save_artifact` method. The method takes three arguments:
+
+- `name`: The name we want to give to the artifact.
+- `write_function`: A function that writes the artifact to the desired path. The function will receive a `path` argument which is a `pathlib.Path` object pointing to the directory where the artifact should be saved.
+- `metadata`: A dictionary with metadata about the artifact. This metadata will be saved along with the artifact.
+
+Let's execute the step with a simple pipeline and push the resulting `Distiset` to the Hugging Face Hub:
+
+??? "Example full code"
+
+ ```python
+ from typing import TYPE_CHECKING, List
+
+ import matplotlib.pyplot as plt
+ from datasets import load_dataset
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps import GlobalStep, StepInput, StepOutput
+
+ if TYPE_CHECKING:
+ from distilabel.steps import StepOutput
+
+
+ class CountTextCharacters(GlobalStep):
+ @property
+ def inputs(self) -> List[str]:
+ return ["text"]
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["text_character_count"]
+
+ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
+ character_counts = []
+
+ for input in inputs:
+ text_character_count = len(input["text"])
+ input["text_character_count"] = text_character_count
+ character_counts.append(text_character_count)
+
+ # Generate plot with the distribution of text character counts
+ plt.figure(figsize=(10, 6))
+ plt.hist(character_counts, bins=30, edgecolor="black")
+ plt.title("Distribution of Text Character Counts")
+ plt.xlabel("Character Count")
+ plt.ylabel("Frequency")
+
+ # Save the plot as an artifact of the step
+ self.save_artifact(
+ name="text_character_count_distribution",
+ write_function=lambda path: plt.savefig(path / "figure.png"),
+ metadata={"type": "image", "library": "matplotlib"},
+ )
+
+ plt.close()
+
+ yield inputs
+
+
+ with Pipeline() as pipeline:
+ count_text_characters = CountTextCharacters()
+
+ if __name__ == "__main__":
+ distiset = pipeline.run(
+ dataset=load_dataset(
+ "HuggingFaceH4/instruction-dataset", split="test"
+ ).rename_column("prompt", "text"),
+ )
+
+ distiset.push_to_hub("distilabel-internal-testing/distilabel-artifacts-example")
+ ```
+
+The generated [distilabel-internal-testing/distilabel-artifacts-example](https://huggingface.co/datasets/distilabel-internal-testing/distilabel-artifacts-example) dataset repository has a section in its card [describing the artifacts generated by the pipeline](https://huggingface.co/datasets/distilabel-internal-testing/distilabel-artifacts-example#artifacts) and the generated plot can be seen [here](https://huggingface.co/datasets/distilabel-internal-testing/distilabel-artifacts-example/blob/main/artifacts/count_text_characters_0/text_character_count_distribution/figure.png).
diff --git a/docs/sections/how_to_guides/advanced/scaling_with_ray.md b/docs/sections/how_to_guides/advanced/scaling_with_ray.md
index 4a8b480126..be959c8b72 100644
--- a/docs/sections/how_to_guides/advanced/scaling_with_ray.md
+++ b/docs/sections/how_to_guides/advanced/scaling_with_ray.md
@@ -85,7 +85,7 @@ if __name__ == "__main__":
1. We're setting [resources](assigning_resources_to_step.md) for the `text_generation` step and defining that we want two replicas and one GPU per replica. `distilabel` will create two replicas of the step i.e. two actors in the Ray cluster, and each actor will request to be allocated in a node of the cluster that have at least one GPU. You can read more about how Ray manages the resources [here](https://docs.ray.io/en/latest/ray-core/scheduling/resources.html#resources).
2. You should modify this and add your user or organization on the Hugging Face Hub.
-It's a basic pipeline with just two steps: one to load a dataset from the Hub with an `instruction` column and one to generate a `response` for that instruction using Llama 3 8B Instruct with [vLLM](/distilabel/components-gallery/llms/vllm/). Simple but enough to demonstrate how to distribute and scale the workload using a Ray cluster!
+It's a basic pipeline with just two steps: one to load a dataset from the Hub with an `instruction` column and one to generate a `response` for that instruction using Llama 3 8B Instruct with [vLLM](../../../components-gallery/llms/vllm.md). Simple but enough to demonstrate how to distribute and scale the workload using a Ray cluster!
### Using Ray Jobs API
diff --git a/docs/sections/how_to_guides/advanced/structured_generation.md b/docs/sections/how_to_guides/advanced/structured_generation.md
index d3e750aa93..6f907951c1 100644
--- a/docs/sections/how_to_guides/advanced/structured_generation.md
+++ b/docs/sections/how_to_guides/advanced/structured_generation.md
@@ -111,7 +111,7 @@ These were some simple examples, but one can see the options this opens.
!!! Tip
A full pipeline example can be seen in the following script:
- [`examples/structured_generation_with_outlines.py`](../../pipeline_samples/examples/index.md#llamacpp-with-outlines)
+ [`examples/structured_generation_with_outlines.py`](../../pipeline_samples/examples/llama_cpp_with_outlines.md)
[^1]:
You can check the variable type by importing it from:
@@ -129,7 +129,7 @@ These were some simple examples, but one can see the options this opens.
## Instructor
-When working with model providers behind an API, there's no direct way of accessing the internal logit processor like `outlines` does, but thanks to [`instructor`](https://python.useinstructor.com/) we can generate structured output from LLM providers based on `pydantic.BaseModel` objects. We have integrated `instructor` to deal with the [`AsyncLLM`][distilabel.llms.AsyncLLM], so you can work with the following LLMs: [`OpenAILLM`][distilabel.llms.OpenAILLM], [`AzureOpenAILLM`][distilabel.llms.AzureOpenAILLM], [`CohereLLM`][distilabel.llms.CohereLLM], [`GroqLLM`][distilabel.llms.GroqLLM], [`LiteLLM`][distilabel.llms.LiteLLM] and [`MistralLLM`][distilabel.llms.MistralLLM].
+For other LLM providers behind APIs, there's no direct way of accessing the internal logit processor like `outlines` does, but thanks to [`instructor`](https://python.useinstructor.com/) we can generate structured output from LLM providers based on `pydantic.BaseModel` objects. We have integrated `instructor` to deal with the [`AsyncLLM`][distilabel.llms.AsyncLLM].
!!! Note
For `instructor` integration to work you may need to install the corresponding dependencies:
@@ -155,14 +155,15 @@ class User(BaseModel):
And then we provide that schema to the `structured_output` argument of the LLM:
-!!! Note
- In this example we are using *open-mixtral-8x22b*, keep in mind not all the models work with the function calling functionality required for this example to work.
+!!! NOTE
+ In this example we are using *Meta Llama 3.1 8B Instruct*, keep in mind not all the models support structured outputs.
```python
from distilabel.llms import MistralLLM
-llm = MistralLLM(
- model="open-mixtral-8x22b",
+llm = InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
structured_output={"schema": User}
)
llm.load()
@@ -189,7 +190,7 @@ We get back a Python dictionary (formatted as a string) that we can parse using
!!! Tip
A full pipeline example can be seen in the following script:
- [`examples/structured_generation_with_instructor.py`](../../pipeline_samples/examples/index.md#mistralai-with-instructor)
+ [`examples/structured_generation_with_instructor.py`](../../pipeline_samples/examples/mistralai_with_instructor.md)
## OpenAI JSON
diff --git a/docs/sections/how_to_guides/basic/llm/index.md b/docs/sections/how_to_guides/basic/llm/index.md
index 4bd5f9de2b..f9dec754ae 100644
--- a/docs/sections/how_to_guides/basic/llm/index.md
+++ b/docs/sections/how_to_guides/basic/llm/index.md
@@ -1,16 +1,16 @@
-# Define LLMs as local or remote models
+# Executing Tasks with LLMs
## Working with LLMs
LLM subclasses are designed to be used within a [Task][distilabel.steps.tasks.Task], but they can also be used standalone.
```python
-from distilabel.llms import OpenAILLM
+from distilabel.llms import InferenceEndpointsLLM
-llm = OpenAILLM(model="gpt-4")
+llm = InferenceEndpointsLLM(model="meta-llama/Meta-Llama-3.1-70B-Instruct")
llm.load()
-llm.generate(
+llm.generate_outputs(
inputs=[
[{"role": "user", "content": "What's the capital of Spain?"}],
],
@@ -21,6 +21,69 @@ llm.generate(
!!! NOTE
Always call the `LLM.load` or `Task.load` method when using LLMs standalone or as part of a `Task`. If using a `Pipeline`, this is done automatically in `Pipeline.run()`.
+### Offline Batch Generation
+
+By default, all `LLM`s will generate text in a synchronous manner i.e. send inputs using `generate_outputs` method that will get blocked until outputs are generated. There are some `LLM`s (such as [OpenAILLM][distilabel.llms.openai.OpenAILLM]) that implements what we denote as _offline batch generation_, which allows to send the inputs to the LLM-as-a-service which will generate the outputs asynchronously and give us a job id that we can use later to check the status and retrieve the generated outputs when they are ready. LLM-as-a-service platforms offers this feature as a way to save costs in exchange of waiting for the outputs to be generated.
+
+To use this feature in `distilabel` the only thing we need to do is to set the `use_offline_batch_generation` attribute to `True` when creating the `LLM` instance:
+
+```python
+from distilabel.llms import OpenAILLM
+
+llm = OpenAILLM(
+ model="gpt-4o",
+ use_offline_batch_generation=True,
+)
+
+llm.load()
+
+llm.jobs_ids # (1)
+# None
+
+llm.generate_outputs( # (2)
+ inputs=[
+ [{"role": "user", "content": "What's the capital of Spain?"}],
+ ],
+)
+# DistilabelOfflineBatchGenerationNotFinishedException: Batch generation with jobs_ids=('batch_OGB4VjKpu2ay9nz3iiFJxt5H',) is not finished
+
+llm.jobs_ids # (3)
+# ('batch_OGB4VjKpu2ay9nz3iiFJxt5H',)
+
+
+llm.generate_outputs( # (4)
+ inputs=[
+ [{"role": "user", "content": "What's the capital of Spain?"}],
+ ],
+)
+# "The capital of Spain is Madrid."
+```
+
+1. At first the `jobs_ids` attribute is `None`.
+2. The first call to `generate_outputs` will send the inputs to the LLM-as-a-service and return a `DistilabelOfflineBatchGenerationNotFinishedException` since the outputs are not ready yet.
+3. After the first call to `generate_outputs` the `jobs_ids` attribute will contain the job ids created for generating the outputs.
+4. The second call or subsequent calls to `generate_outputs` will return the outputs if they are ready or raise a `DistilabelOfflineBatchGenerationNotFinishedException` if they are not ready yet.
+
+The `offline_batch_generation_block_until_done` attribute can be used to block the `generate_outputs` method until the outputs are ready polling the platform the specified amount of seconds.
+
+```python
+from distilabel.llms import OpenAILLM
+
+llm = OpenAILLM(
+ model="gpt-4o",
+ use_offline_batch_generation=True,
+ offline_batch_generation_block_until_done=5, # poll for results every 5 seconds
+)
+llm.load()
+
+llm.generate_outputs(
+ inputs=[
+ [{"role": "user", "content": "What's the capital of Spain?"}],
+ ],
+)
+# "The capital of Spain is Madrid."
+```
+
### Within a Task
Pass the LLM as an argument to the [`Task`][distilabel.steps.tasks.Task], and the task will handle the rest.
@@ -81,7 +144,7 @@ To create custom LLMs, subclass either [`LLM`][distilabel.llms.LLM] for synchron
* `generate`: A method that takes a list of prompts and returns generated texts.
* `agenerate`: A method that takes a single prompt and returns generated texts. This method is used within the `generate` method of the `AsyncLLM` class.
-*
+
* (optional) `get_last_hidden_state`: is a method that will take a list of prompts and return a list of hidden states. This method is optional and will be used by some tasks such as the [`GenerateEmbeddings`][distilabel.steps.tasks.GenerateEmbeddings] task.
@@ -142,4 +205,4 @@ To create custom LLMs, subclass either [`LLM`][distilabel.llms.LLM] for synchron
## Available LLMs
-[Our LLM gallery](/distilabel/components-gallery/llms/) shows a list of the available LLMs that can be used within the `distilabel` library.
\ No newline at end of file
+[Our LLM gallery](../../../../components-gallery/llms/index.md) shows a list of the available LLMs that can be used within the `distilabel` library.
diff --git a/docs/sections/how_to_guides/basic/pipeline/index.md b/docs/sections/how_to_guides/basic/pipeline/index.md
index 2d03f9ea87..f592082191 100644
--- a/docs/sections/how_to_guides/basic/pipeline/index.md
+++ b/docs/sections/how_to_guides/basic/pipeline/index.md
@@ -421,7 +421,7 @@ with Pipeline("pipe-name", description="My first pipe") as pipeline:
VertexAILLM(model="gemini-1.5-pro"),
):
task = TextGeneration(
- name=f"text_generation_with_{llm.model_name}",
+ name=f"text_generation_with_{llm.model_name.replace('.', '-')}",
llm=llm,
input_batch_size=5,
)
@@ -459,6 +459,30 @@ To load the pipeline, we can use the `from_yaml` or `from_json` methods:
Serializing the pipeline is very useful when we want to share the pipeline with others, or when we want to store the pipeline for future use. It can even be hosted online, so the pipeline can be executed directly using the [CLI](../../advanced/cli/index.md).
+## Visualizing the pipeline
+
+We can visualize the pipeline using the `Pipeline.draw()` method. This will create a `mermaid` graph, and return the path to the image.
+
+```python
+path_to_image = pipeline.draw(
+ top_to_bottom=True,
+ show_edge_labels=True,
+)
+```
+
+Within notebooks, we can simply call `pipeline` and the graph will be displayed. Alternatively, we can use the `Pipeline.draw()` method to have more control over the graph visualization and use `IPython` to display it.
+
+```python
+from IPython.display import Image, display
+
+display(Image(path_to_image))
+```
+
+Let's now see how the pipeline of the [fully working example](#fully-working-example) looks like.
+
+![Pipeline](../../../../assets/images/sections/how_to_guides/basic/pipeline.png)
+
+
## Fully working example
To sum up, here is the full code of the pipeline we have created in this section. Note that you will need to change the name of the Hugging Face repository where the resulting will be pushed, set `OPENAI_API_KEY` environment variable, set `MISTRAL_API_KEY` and have `gcloud` installed and configured:
@@ -487,7 +511,9 @@ To sum up, here is the full code of the pipeline we have created in this section
MistralLLM(model="mistral-large-2402"),
VertexAILLM(model="gemini-1.0-pro"),
):
- task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm)
+ task = TextGeneration(
+ name=f"text_generation_with_{llm.model_name.replace('.', '-')}", llm=llm
+ )
load_dataset.connect(task)
task.connect(combine_generations)
diff --git a/docs/sections/how_to_guides/basic/step/generator_step.md b/docs/sections/how_to_guides/basic/step/generator_step.md
index c5b665a82e..0422644c36 100644
--- a/docs/sections/how_to_guides/basic/step/generator_step.md
+++ b/docs/sections/how_to_guides/basic/step/generator_step.md
@@ -3,17 +3,19 @@
The [`GeneratorStep`][distilabel.steps.GeneratorStep] is a subclass of [`Step`][distilabel.steps.Step] that is intended to be used as the first step within a [`Pipeline`][distilabel.pipeline.Pipeline], because it doesn't require input and generates data that can be used by other steps. Alternatively, it can also be used as a standalone.
```python
-from typing import List
+from typing import List, TYPE_CHECKING
from typing_extensions import override
from distilabel.steps import GeneratorStep
-from distilabel.steps.typing import GeneratorStepOutput
+
+if TYPE_CHECKING:
+ from distilabel.steps.typing import StepColumns, GeneratorStepOutput
class MyGeneratorStep(GeneratorStep):
instructions: List[str]
@override
- def process(self, offset: int = 0) -> GeneratorStepOutput:
+ def process(self, offset: int = 0) -> "GeneratorStepOutput":
if offset:
self.instructions = self.instructions[offset:]
@@ -30,7 +32,7 @@ class MyGeneratorStep(GeneratorStep):
)
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
return ["instruction"]
```
@@ -57,7 +59,7 @@ next(step.process(offset=1))
We can define a custom generator step by creating a new subclass of the [`GeneratorStep`][distilabel.steps.GeneratorStep] and defining the following:
-- `outputs`: is a property that returns a list of strings with the names of the output fields.
+- `outputs`: is a property that returns a list of strings with the names of the output fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not.
- `process`: is a method that yields output data and a boolean flag indicating whether that's the last batch to be generated.
@@ -73,21 +75,23 @@ We can define a custom generator step by creating a new subclass of the [`Genera
```python
- from typing import List
+ from typing import List, TYPE_CHECKING
from typing_extensions import override
from distilabel.steps import GeneratorStep
- from distilabel.steps.typing import GeneratorStepOutput
+
+ if TYPE_CHECKING:
+ from distilabel.steps.typing import StepColumns, GeneratorStepOutput
class MyGeneratorStep(GeneratorStep):
instructions: List[str]
@override
- def process(self, offset: int = 0) -> GeneratorStepOutput:
+ def process(self, offset: int = 0) -> "GeneratorStepOutput":
...
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
...
```
@@ -96,15 +100,18 @@ We can define a custom generator step by creating a new subclass of the [`Genera
The `@step` decorator will take care of the boilerplate code, and will allow to define the `outputs`, and `process` methods in a more straightforward way. One downside is that it won't let you access the `self` attributes if any, neither set those, so if you need to access or set any attribute, you should go with the first approach of defining the custom [`GeneratorStep`][distilabel.steps.GeneratorStep] subclass.
```python
+ from typing import TYPE_CHECKING
from distilabel.steps import step
- from distilabel.steps.typing import GeneratorStepOutput
+
+ if TYPE_CHECKING:
+ from distilabel.steps.typing import GeneratorStepOutput
@step(outputs=[...], step_type="generator")
- def CustomGeneratorStep(offset: int = 0) -> GeneratorStepOutput:
+ def CustomGeneratorStep(offset: int = 0) -> "GeneratorStepOutput":
yield (
...,
True if offset == 10 else False,
)
step = CustomGeneratorStep(name="my-step")
- ```
\ No newline at end of file
+ ```
diff --git a/docs/sections/how_to_guides/basic/step/global_step.md b/docs/sections/how_to_guides/basic/step/global_step.md
index c9044a87d2..814f01a0fb 100644
--- a/docs/sections/how_to_guides/basic/step/global_step.md
+++ b/docs/sections/how_to_guides/basic/step/global_step.md
@@ -6,9 +6,9 @@ The [`GlobalStep`][distilabel.steps.GlobalStep] is a subclass of [`Step`][distil
We can define a custom step by creating a new subclass of the [`GlobalStep`][distilabel.steps.GlobalStep] and defining the following:
-- `inputs`: is a property that returns a list of strings with the names of the required input fields.
+- `inputs`: is a property that returns a list of strings with the names of the required input fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not.
-- `outputs`: is a property that returns a list of strings with the names of the output fields.
+- `outputs`: is a property that returns a list of strings with the names of the output fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not.
- `process`: is a method that receives the input data and returns the output data, and it should be a generator, meaning that it should `yield` the output data.
@@ -23,20 +23,23 @@ We can define a custom step by creating a new subclass of the [`GlobalStep`][dis
We can inherit from the `GlobalStep` class and define the `inputs`, `outputs`, and `process` methods as follows:
```python
+ from typing import TYPE_CHECKING
from distilabel.steps import GlobalStep, StepInput
- from distilabel.steps.typing import StepOutput
+
+ if TYPE_CHECKING:
+ from distilabel.steps.typing import StepColumns, StepOutput
class CustomStep(Step):
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
...
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
...
def process(self, *inputs: StepInput) -> StepOutput:
- for input in inputs:
+ for upstream_step_inputs in inputs:
for item in input:
...
yield item
@@ -54,14 +57,17 @@ We can define a custom step by creating a new subclass of the [`GlobalStep`][dis
The `@step` decorator will take care of the boilerplate code, and will allow to define the `inputs`, `outputs`, and `process` methods in a more straightforward way. One downside is that it won't let you access the `self` attributes if any, neither set those, so if you need to access or set any attribute, you should go with the first approach of defining the custom [`GlobalStep`][distilabel.steps.GlobalStep] subclass.
```python
+ from typing import TYPE_CHECKING
from distilabel.steps import StepInput, step
- from distilabel.steps.typing import StepOutput
+
+ if TYPE_CHECKING:
+ from distilabel.steps.typing import StepOutput
@step(inputs=[...], outputs=[...], step_type="global")
- def CustomStep(inputs: StepInput) -> StepOutput:
+ def CustomStep(inputs: StepInput) -> "StepOutput":
for input in inputs:
...
yield inputs
step = CustomStep(name="my-step")
- ```
\ No newline at end of file
+ ```
diff --git a/docs/sections/how_to_guides/basic/step/index.md b/docs/sections/how_to_guides/basic/step/index.md
index e3b19e5334..18388b8f4a 100644
--- a/docs/sections/how_to_guides/basic/step/index.md
+++ b/docs/sections/how_to_guides/basic/step/index.md
@@ -1,4 +1,4 @@
-# Define Steps for your Pipeline
+# Steps for processing data
## Working with Steps
@@ -7,13 +7,19 @@ The [`Step`][distilabel.steps.Step] is intended to be used within the scope of a
Assuming that we have a [`Step`][distilabel.steps.Step] already defined as it follows:
```python
+from typing import TYPE_CHECKING
+from distilabel.steps import Step, StepInput
+
+if TYPE_CHECKING:
+ from distilabel.steps.typing import StepColumns, StepOutput
+
class MyStep(Step):
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
return ["input_field"]
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
return ["output_field"]
def process(self, inputs: StepInput) -> "StepOutput":
@@ -44,7 +50,7 @@ next(step.process([{"input_field": "value"}]))
### Runtime parameters
-`Step`s can also have `RuntimeParameter`, which are parameters that can only used after the pipeline initialisation when calling the `Pipeline.run`.
+`Step`s can also have `RuntimeParameter`, which are parameters that can only be used after the pipeline initialisation when calling the `Pipeline.run`.
```python
from distilabel.mixins.runtime_parameters import RuntimeParameter
@@ -71,9 +77,9 @@ There are two special types of [`Step`][distilabel.steps.Step] in `distilabel`:
We can define a custom step by creating a new subclass of the [`Step`][distilabel.steps.Step] and defining the following:
-- `inputs`: is a property that returns a list of strings with the names of the required input fields.
+- `inputs`: is a property that returns a list of strings with the names of the required input fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not.
-- `outputs`: is a property that returns a list of strings with the names of the output fields.
+- `outputs`: is a property that returns a list of strings with the names of the output fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not.
- `process`: is a method that receives the input data and returns the output data, and it should be a generator, meaning that it should `yield` the output data.
@@ -88,20 +94,23 @@ We can define a custom step by creating a new subclass of the [`Step`][distilabe
We can inherit from the `Step` class and define the `inputs`, `outputs`, and `process` methods as follows:
```python
+ from typing import TYPE_CHECKING
from distilabel.steps import Step, StepInput
- from distilabel.steps.typing import StepOutput
+
+ if TYPE_CHECKING:
+ from distilabel.steps.typing import StepColumns, StepOutput
class CustomStep(Step):
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
...
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
...
- def process(self, *inputs: StepInput) -> StepOutput:
- for input in inputs:
+ def process(self, *inputs: StepInput) -> "StepOutput":
+ for upstream_step_inputs in inputs:
...
yield item
@@ -119,14 +128,17 @@ We can define a custom step by creating a new subclass of the [`Step`][distilabe
```python
+ from typing import TYPE_CHECKING
from distilabel.steps import StepInput, step
- from distilabel.steps.typing import StepOutput
+
+ if TYPE_CHECKING:
+ from distilabel.steps.typing import StepOutput
@step(inputs=[...], outputs=[...])
- def CustomStep(inputs: StepInput) -> StepOutput:
+ def CustomStep(inputs: StepInput) -> "StepOutput":
for input in inputs:
...
yield inputs
step = CustomStep(name="my-step")
- ```
\ No newline at end of file
+ ```
diff --git a/docs/sections/how_to_guides/basic/task/generator_task.md b/docs/sections/how_to_guides/basic/task/generator_task.md
index 040af877d9..613d8deb17 100644
--- a/docs/sections/how_to_guides/basic/task/generator_task.md
+++ b/docs/sections/how_to_guides/basic/task/generator_task.md
@@ -1,4 +1,4 @@
-# GeneratorTask
+# GeneratorTask that produces output
## Working with GeneratorTasks
diff --git a/docs/sections/how_to_guides/basic/task/index.md b/docs/sections/how_to_guides/basic/task/index.md
index 70b118c3ea..7aa2049f4b 100644
--- a/docs/sections/how_to_guides/basic/task/index.md
+++ b/docs/sections/how_to_guides/basic/task/index.md
@@ -1,4 +1,4 @@
-# Define Tasks that rely on LLMs
+# Tasks for generating and judging with LLMs
## Working with Tasks
@@ -24,7 +24,12 @@ next(task.process([{"instruction": "What's the capital of Spain?"}]))
# {
# 'instruction': "What's the capital of Spain?",
# 'generation': 'The capital of Spain is Madrid.',
-# 'distilabel_metadata': {'raw_output_text-generation': 'The capital of Spain is Madrid.'},
+# 'distilabel_metadata': {
+# 'raw_output_text-generation': 'The capital of Spain is Madrid.',
+# 'raw_input_text-generation': [
+# {'role': 'user', 'content': "What's the capital of Spain?"}
+# ]
+# },
# 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct'
# }
# ]
@@ -33,7 +38,93 @@ next(task.process([{"instruction": "What's the capital of Spain?"}]))
!!! NOTE
The `Step.load()` always needs to be executed when being used as a standalone. Within a pipeline, this will be done automatically during pipeline execution.
-As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task adds a `generation` based on the `instruction`. Additionally, it provides some metadata about the LLM call through `distilabel_metadata`. This can be disabled by setting the `add_raw_output` attribute to `False` when creating the task.
+As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task adds a `generation` based on the `instruction`.
+
+!!! Tip
+ Since version `1.2.0`, we provide some metadata about the LLM call through `distilabel_metadata`. This can be disabled by setting the `add_raw_output` attribute to `False` when creating the task.
+
+ Additionally, since version `1.4.0`, the formatted input can also be included, which can be helpful when testing
+ custom templates (testing the pipeline using the [`dry_run`][distilabel.pipeline.local.Pipeline.dry_run] method).
+
+ ```python title="disable raw input and output"
+ task = TextGeneration(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+ add_raw_output=False,
+ add_raw_input=False
+ )
+ ```
+
+### Task.print
+
+!!! Info
+ New since version `1.4.0`, [`Task.print`][distilabel.steps.tasks.base._Task.print] `Task.print` method.
+
+The `Tasks` include a handy method to show what the prompt formatted for an `LLM` would look like, let's see an example with [`UltraFeedback`][distilabel.steps.tasks.ultrafeedback.UltraFeedback], but it applies to any other `Task`.
+
+```python
+from distilabel.steps.tasks import UltraFeedback
+from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+uf = UltraFeedback(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+)
+uf.load()
+uf.print()
+```
+
+The result will be a rendered prompt, with the System prompt (if contained for the task) and the User prompt, rendered with rich (it will show exactly the same in a jupyter notebook).
+
+![task-print](../../../../assets/images/sections/how_to_guides/tasks/task_print.png)
+
+In case you want to test with a custom input, you can pass an example to the tasks` `format_input` method (or generate it on your own depending on the task), and pass it to the print method so that it shows your example:
+
+
+```python
+uf.print(
+ uf.format_input({"instruction": "test", "generations": ["1", "2"]})
+)
+```
+
+??? "Using a DummyLLM to avoid loading one"
+
+ In case you don't want to load an LLM to render the template, you can create a dummy one like the ones we could use for testing.
+
+ ```python
+ from distilabel.llms.base import LLM
+ from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+
+ class DummyLLM(AsyncLLM, MagpieChatTemplateMixin):
+ structured_output: Any = None
+ magpie_pre_query_template: str = "llama3"
+
+ def load(self) -> None:
+ pass
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ def generate(
+ self, input: "FormattedInput", num_generations: int = 1
+ ) -> "GenerateOutput":
+ return ["output" for _ in range(num_generations)]
+ ```
+
+ You can use this `LLM` just as any of the other ones to `load` your task and call `print`:
+
+ ```python
+ uf = UltraFeedback(llm=DummyLLM())
+ uf.load()
+ uf.print()
+ ```
+
+!!! Note
+ When creating a custom task, the `print` method will be available by default, but it is limited to the most common scenarios for the inputs. If you test your new task and find it's not working as expected (for example, if your task contains one input consisting of a list of texts instead of a single one), you should override the `_sample_input` method. You can inspect the [`UltraFeedback`][distilabel.steps.tasks.ultrafeedback.UltraFeedback] source code for this.
## Specifying the number of generations and grouping generations
@@ -112,27 +203,30 @@ next(task.process([{"instruction": "What's the capital of Spain?"}]))
We can define a custom step by creating a new subclass of the [`Task`][distilabel.steps.tasks.Task] and defining the following:
-- `inputs`: is a property that returns a list of strings with the names of the required input fields.
+- `inputs`: is a property that returns a list of strings with the names of the required input fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not.
- `format_input`: is a method that receives a dictionary with the input data and returns a [`ChatType`][distilabel.steps.tasks.ChatType] following [the chat-completion OpenAI message formatting](https://platform.openai.com/docs/guides/text-generation).
-- `outputs`: is a property that returns a list of strings with the names of the output fields, this property should always include `model_name` as one of the outputs since that's automatically injected from the LLM.
+- `outputs`: is a property that returns a list of strings with the names of the output fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not. This property should always include `model_name` as one of the outputs since that's automatically injected from the LLM.
- `format_output`: is a method that receives the output from the [`LLM`][distilabel.llms.LLM] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that there's no need to include the `model_name` in the output.
```python
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Union, TYPE_CHECKING
from distilabel.steps.tasks.base import Task
-from distilabel.steps.tasks.typing import ChatType
+
+if TYPE_CHECKING:
+ from distilabel.steps.typing import StepColumns
+ from distilabel.steps.tasks.typing import ChatType
class MyCustomTask(Task):
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
return ["input_field"]
- def format_input(self, input: Dict[str, Any]) -> ChatType:
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
return [
{
"role": "user",
@@ -141,7 +235,7 @@ class MyCustomTask(Task):
]
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
return ["output_field", "model_name"]
def format_output(
diff --git a/docs/sections/pipeline_samples/examples/benchmarking_with_distilabel.md b/docs/sections/pipeline_samples/examples/benchmarking_with_distilabel.md
new file mode 100644
index 0000000000..5f775f604e
--- /dev/null
+++ b/docs/sections/pipeline_samples/examples/benchmarking_with_distilabel.md
@@ -0,0 +1,22 @@
+---
+hide: toc
+---
+# Benchmarking with `distilabel`
+
+Benchmark LLMs with `distilabel`: reproducing the Arena Hard benchmark.
+
+The script below first defines both the `ArenaHard` and the `ArenaHardResults` tasks, so as to generate responses for a given collection of prompts/questions with up to two LLMs, and then calculate the results as per the original implementation, respectively. Additionally, the second part of the example builds a `Pipeline` to run the generation on top of the prompts with `InferenceEndpointsLLM` while streaming the rest of the generations from a pre-computed set of GPT-4 generations, and then evaluate one against the other with `OpenAILLM` generating an alternate response, a comparison between the responses, and a result as A>>B, A>B, B>A, B>>A, or tie.
+
+![Arena Hard](../../../assets/pipelines/arena-hard.png)
+
+To run this example you will first need to install the Arena Hard optional dependencies, being `pandas`, `scikit-learn`, and `numpy`.
+
+??? Run
+
+ ```python
+ python examples/arena_hard.py
+ ```
+
+```python title="arena_hard.py"
+--8<-- "examples/arena_hard.py"
+```
\ No newline at end of file
diff --git a/docs/sections/pipeline_samples/examples/fine_personas_social_network.md b/docs/sections/pipeline_samples/examples/fine_personas_social_network.md
new file mode 100644
index 0000000000..52df495fc4
--- /dev/null
+++ b/docs/sections/pipeline_samples/examples/fine_personas_social_network.md
@@ -0,0 +1,232 @@
+---
+hide: toc
+---
+
+# Create a social network with FinePersonas
+
+In this example, we'll explore the creation of specialized user personas for social network interactions using the [FinePersonas-v0.1](https://huggingface.co/datasets/argilla/FinePersonas-v0.1) dataset from Hugging Face. The final dataset will be ready to fine-tune a chat model with specific traits and characteristics.
+
+## Introduction
+
+We'll delve into the process of fine-tuning different LoRA (Low-Rank Adaptation) models to imbue these personas with specific traits and characteristics.
+
+This approach draws inspiration from Michael Sayman's work on [SocialAI](https://apps.apple.com/us/app/socialai-ai-social-network/id6670229993) (visit the [profile](https://x.com/michaelsayman) to see some examples), to leverage [FinePersonas-v0.1](https://huggingface.co/datasets/argilla/FinePersonas-v0.1) for building models that can emulate bots with specific behaviour.
+
+By fine-tuning these adapters, we can potentially create AI personas with distinct characteristics, communication styles, and areas of expertise. The result? AI interactions that feel more natural and tailored to specific contexts or user needs. For those interested in the technical aspects of this approach, we recommend the insightful blog post on [Multi-LoRA serving](https://huggingface.co/blog/multi-lora-serving). It provides a clear and comprehensive explanation of the technology behind this innovative method.
+
+Let's jump to the demo.
+
+## Creating our SocialAI Task
+
+Building on the new [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/), creating custom tasks is easier than ever before. This powerful tool opens up a world of possibilities for creating tailored text-based content with ease and precision. We will create a `SocialAI` task that will be in charge of generating responses to user interactions, taking into account a given `follower_type`, and use the perspective from a given `persona`:
+
+```python
+from distilabel.steps.tasks import TextGeneration
+
+class SocialAI(TextGeneration):
+ follower_type: Literal["supporter", "troll", "alarmist"] = "supporter"
+ system_prompt: str = (
+ "You are an AI assistant expert at simulating user interactions. "
+ "You must answer as if you were a '{follower_type}', be concise answer with no more than 200 characters, nothing else."
+ "Here are some traits to use for your personality:\n\n"
+ "{traits}"
+ ) # (1)
+ template: str = "You are the folowing persona:\n\n{{ persona }}\n\nWhat would you say to the following?\n\n {{ post }}" # (2)
+ columns: str | list[str] = ["persona", "post"] # (3)
+
+ _follower_traits: dict[str, str] = {
+ "supporter": (
+ "- Encouraging and positive\n"
+ "- Tends to prioritize enjoyment and relaxation\n"
+ "- Focuses on the present moment and short-term pleasure\n"
+ "- Often uses humor and playful language\n"
+ "- Wants to help others feel good and have fun\n"
+ ),
+ "troll": (
+ "- Provocative and confrontational\n"
+ "- Enjoys stirring up controversy and conflict\n"
+ "- Often uses sarcasm, irony, and mocking language\n"
+ "- Tends to belittle or dismiss others' opinions and feelings\n"
+ "- Seeks to get a rise out of others and create drama\n"
+ ),
+ "alarmist": (
+ "- Anxious and warning-oriented\n"
+ "- Focuses on potential risks and negative consequences\n"
+ "- Often uses dramatic or sensational language\n"
+ "- Tends to be serious and stern in tone\n"
+ "- Seeks to alert others to potential dangers and protect them from harm (even if it's excessive or unwarranted)\n"
+ ),
+ }
+
+ def load(self) -> None:
+ super().load()
+ self.system_prompt = self.system_prompt.format(
+ follower_type=self.follower_type,
+ traits=self._follower_traits[self.follower_type]
+ ) # (4)
+```
+
+1. We have a custom system prompt that will depend on the `follower_type` we decide for our model.
+
+2. The base template or prompt will answert to the `post` we have, from the point of view of a `persona`.
+
+3. We will need our dataset to have both `persona` and `post` columns to populate the prompt.
+
+4. In the load method we place the specific traits for our follower type in the system prompt.
+
+## Data preparation
+
+This is an example, so let's keep it short. We will use 3 posts, and 3 different types of personas. While there's potential to enhance this process (perhaps by implementing random persona selection or leveraging semantic similarity) we'll opt for a straightforward method in this demonstration.
+
+Our goal is to create a set of nine examples, each pairing a post with a persona. To achieve this, we'll employ an LLM to respond to each post from the perspective of a specific `persona`, effectively simulating how different characters might engage with the content.
+
+```python
+posts = [
+ {
+ "post": "Hmm, ok now I'm torn: should I go for healthy chicken tacos or unhealthy beef tacos for late night cravings?"
+ },
+ {
+ "post": "I need to develop a training course for my company on communication skills. Need to decide how deliver it remotely."
+ },
+ {
+ "post": "I'm always 10 minutes late to meetups but no one's complained. Could this be annoying to them?"
+ },
+]
+
+personas = (
+ load_dataset("argilla/FinePersonas-v0.1-clustering-100k", split="train")
+ .shuffle()
+ .select(range(3))
+ .select_columns("persona")
+ .to_list()
+)
+
+data = []
+for post in posts:
+ for persona in personas:
+ data.append({"post": post["post"], "persona": persona["persona"]})
+```
+
+Each row in will have the following format:
+
+```python
+import json
+print(json.dumps(data[0], indent=4))
+{
+ "post": "Hmm, ok now I'm torn: should I go for healthy chicken tacos or unhealthy beef tacos for late night cravings?",
+ "persona": "A high school or college environmental science teacher or an ecology student specializing in biogeography and ecosystem dynamics."
+}
+```
+
+This will be our dataset, that we can ingest using the [`LoadDataFromDicts`](https://distilabel.argilla.io/dev/components-gallery/steps/loaddatafromdicts/):
+
+```python
+loader = LoadDataFromDicts(data=data)
+```
+
+## Simulating from different types of followers
+
+With our data in hand, we're ready to explore the capabilities of our SocialAI task. For this demonstration, we'll make use of of `meta-llama/Meta-Llama-3.1-70B-Instruct`
+While this model has become something of a go-to choice recently, it's worth noting that experimenting with a variety of models could yield even more interesting results:
+
+```python
+from distilabel.llms import InferenceEndpointsLLM
+
+llm = InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 256,
+ },
+)
+follower_type = "supporter"
+
+follower = SocialAI(
+ llm=llm,
+ follower_type=follower_type,
+ name=f"{follower_type}_user",
+)
+```
+
+This setup simplifies the process, we only need to input the follower type, and the system handles the rest. We could update this too to have a random type of follower by default, and simulate from a bunch of different personalities.
+
+## Building our Pipeline
+
+The foundation of our pipeline is now in place. At its core is a single, powerful LLM. This versatile model will be repurposed to drive three distinct `SocialAI` Tasks, each tailored to a specific `TextGeneration` task, and each one of them will be prepared for Supervised Fine Tuning using [`FormatTextGenerationSFT`](https://distilabel.argilla.io/dev/components-gallery/steps/formattextgenerationsft/):
+
+```python
+with Pipeline(name="Social AI Personas") as pipeline:
+ loader = LoadDataFromDicts(data=data, batch_size=1)
+
+ llm = InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 256,
+ },
+ )
+
+ for follower_type in ["supporter", "troll", "alarmist"]:
+ follower = SocialAI(
+ llm=llm,
+ follower_type=follower_type,
+ name=f"{follower_type}_user", # (1)
+ output_mappings={
+ "generation": f"interaction_{follower_type}" # (2)
+ }
+ )
+ format_sft = FormatTextGenerationSFT(
+ name=f"format_sft_{follower_type}",
+ input_mappings={
+ "instruction": "post",
+ "generation": f"interaction_{follower_type}" # (3)
+ },
+ )
+ loader >> follower >> format_sft # (4)
+```
+
+1. We update the name of the step to keep track in the pipeline.
+
+2. The `generation` column from each LLM will be mapped to avoid them being overriden, as we are reusing the same task.
+
+3. As we have modified the output column from `SocialAI`, we redirect each one of the "follower_type" responses.
+
+4. Connect the loader to each one of the follower tasks and `format_sft` to obtain 3 different subsets.
+
+The outcome of this pipeline will be three specialized models, each fine-tuned to a unique `follower type` crafted by the `SocialAI` task. These models will generate SFT-formatted datasets, where each post is paired with its corresponding interaction data for a specific follower type. This setup enables seamless fine-tuning using your preferred framework, such as [TRL](https://huggingface.co/docs/trl/index), or any other training framework of your choice.
+
+## Script and final dataset
+
+All the pieces are in place for our script, the full pipeline can be seen here:
+
+??? Run
+
+ ```python
+ python examples/finepersonas_social_ai.py
+ ```
+
+```python title="finepersonas_social_ai.py"
+--8<-- "examples/finepersonas_social_ai.py"
+```
+
+This is the final toy dataset we obtain: [FinePersonas-SocialAI-test](https://huggingface.co/datasets/plaguss/FinePersonas-SocialAI-test)
+
+You can see examples of how to load each subset of them to fine-tune a model:
+
+```python
+from datasets import load_dataset
+
+ds = load_dataset("plaguss/FinePersonas-SocialAI-test", "format_sft_troll")
+```
+
+And a sample of the generated field with the corresponding `post` and `persona`:
+
+```json
+{
+ "post": "Hmm, ok now I\u0027m torn: should I go for healthy chicken tacos or unhealthy beef tacos for late night cravings?",
+ "persona": "A high school or undergraduate physics or chemistry teacher, likely with a focus on experimental instruction.",
+ "interaction_troll": "\"Late night cravings? More like late night brain drain. Either way, it\u0027s just a collision of molecules in your stomach. Choose the one with more calories, at least that\u0027s some decent kinetic energy.\"",
+}
+```
+
+There's a lot of room for improvement, but quite a promising start.
diff --git a/docs/sections/pipeline_samples/examples/index.md b/docs/sections/pipeline_samples/examples/index.md
deleted file mode 100644
index 68e25fc888..0000000000
--- a/docs/sections/pipeline_samples/examples/index.md
+++ /dev/null
@@ -1,78 +0,0 @@
-# Examples
-
-This section contains different example pipelines that showcase different tasks, maybe you can take inspiration from them.
-
-### [llama.cpp with `outlines`](#llamacpp-with-outlines)
-
-Generate RPG characters following a `pydantic.BaseModel` with `outlines` in `distilabel`.
-
-??? Example "See example"
-
- This script makes use of [`LlamaCppLLM`][distilabel.llms.llamacpp.LlamaCppLLM] and the structured output capabilities thanks to [`outlines`](https://outlines-dev.github.io/outlines/welcome/) to generate RPG characters that adhere to a JSON schema.
-
- It makes use of a local model which can be downloaded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.llms.vllm.vLLM].
-
- ??? Run
-
- ```python
- python examples/structured_generation_with_outlines.py
- ```
-
- ```python title="structured_generation_with_outlines.py"
- --8<-- "examples/structured_generation_with_outlines.py"
- ```
-
-
-### [MistralAI with `instructor`](#mistralai-with-instructor)
-
-Answer instructions with knowledge graphs defined as `pydantic.BaseModel` objects using `instructor` in `distilabel`.
-
-??? Example "See example"
-
- This script makes use of [`MistralLLM`][distilabel.llms.mistral.MistralLLM] and the structured output capabilities thanks to [`instructor`](https://python.useinstructor.com/) to generate knowledge graphs from complex topics.
-
- This example is translated from this [awesome example](https://python.useinstructor.com/examples/knowledge_graph/) from `instructor` cookbook.
-
- ??? Run
-
- ```python
- python examples/structured_generation_with_instructor.py
- ```
-
- ```python title="structured_generation_with_instructor.py"
- --8<-- "examples/structured_generation_with_instructor.py"
- ```
-
- ??? "Visualizing the graphs"
-
- Want to see how to visualize the graphs? You can test it using the following script. Generate some samples on your own and take a look:
-
- !!! NOTE
-
- This example uses graphviz to render the graph, you can install with `pip` in the following way:
-
- ```console
- pip install graphviz
- ```
-
- ```python
- python examples/draw_kg.py 2 # You can pass 0,1,2 to visualize each of the samples.
- ```
-
- ![Knowledge graph figure](../../../assets/images/sections/examples/knowledge-graph-example.png)
-
-
-### [Benchmarking with `distilabel`: Arena Hard](#benchmarking-with-distilabel-arena-hard)
-
-Benchmark LLMs with `distilabel`: reproducing the Arena Hard benchmark.
-
-??? Example "See example"
-
- The script below first defines both the `ArenaHard` and the `ArenaHardResults` tasks, so as to generate responses for a given collection of prompts/questions with up to two LLMs, and then calculate the results as per the original implementation, respectively. Additionally, the second part of the example builds a `Pipeline` to run the generation on top of the prompts with `InferenceEndpointsLLM` while streaming the rest of the generations from a pre-computed set of GPT-4 generations, and then evaluate one against the other with `OpenAILLM` generating an alternate response, a comparison between the responses, and a result as A>>B, A>B, B>A, B>>A, or tie.
-
- To run this example you will first need to install the Arena Hard optional dependencies, being `pandas`, `scikit-learn`, and `numpy`.
-
- ```python title="arena_hard.py"
- --8<-- "examples/arena_hard.py"
- ```
-
diff --git a/docs/sections/pipeline_samples/examples/llama_cpp_with_outlines.md b/docs/sections/pipeline_samples/examples/llama_cpp_with_outlines.md
new file mode 100644
index 0000000000..9ff0bdff8f
--- /dev/null
+++ b/docs/sections/pipeline_samples/examples/llama_cpp_with_outlines.md
@@ -0,0 +1,22 @@
+---
+hide: toc
+---
+# Structured generation with `outlines`
+
+Generate RPG characters following a `pydantic.BaseModel` with `outlines` in `distilabel`.
+
+This script makes use of [`LlamaCppLLM`][distilabel.llms.llamacpp.LlamaCppLLM] and the structured output capabilities thanks to [`outlines`](https://outlines-dev.github.io/outlines/welcome/) to generate RPG characters that adhere to a JSON schema.
+
+![Arena Hard](../../../assets/pipelines/knowledge_graphs.png)
+
+It makes use of a local model which can be downloaded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.llms.vllm.vLLM].
+
+??? Run
+
+ ```python
+ python examples/structured_generation_with_outlines.py
+ ```
+
+```python title="structured_generation_with_outlines.py"
+--8<-- "examples/structured_generation_with_outlines.py"
+```
\ No newline at end of file
diff --git a/docs/sections/pipeline_samples/examples/mistralai_with_instructor.md b/docs/sections/pipeline_samples/examples/mistralai_with_instructor.md
new file mode 100644
index 0000000000..7e081ab222
--- /dev/null
+++ b/docs/sections/pipeline_samples/examples/mistralai_with_instructor.md
@@ -0,0 +1,40 @@
+---
+hide: toc
+---
+# Structured generation with `instructor`
+
+Answer instructions with knowledge graphs defined as `pydantic.BaseModel` objects using `instructor` in `distilabel`.
+
+This script makes use of [`MistralLLM`][distilabel.llms.mistral.MistralLLM] and the structured output capabilities thanks to [`instructor`](https://python.useinstructor.com/) to generate knowledge graphs from complex topics.
+
+![Knowledge graph figure](../../../assets/pipelines/knowledge_graphs.png)
+
+This example is translated from this [awesome example](https://python.useinstructor.com/examples/knowledge_graph/) from `instructor` cookbook.
+
+??? Run
+
+ ```python
+ python examples/structured_generation_with_instructor.py
+ ```
+
+```python title="structured_generation_with_instructor.py"
+--8<-- "examples/structured_generation_with_instructor.py"
+```
+
+??? "Visualizing the graphs"
+
+ Want to see how to visualize the graphs? You can test it using the following script. Generate some samples on your own and take a look:
+
+ !!! NOTE
+
+ This example uses graphviz to render the graph, you can install with `pip` in the following way:
+
+ ```console
+ pip install graphviz
+ ```
+
+ ```python
+ python examples/draw_kg.py 2 # You can pass 0,1,2 to visualize each of the samples.
+ ```
+
+ ![Knowledge graph figure](../../../assets/images/sections/examples/knowledge-graph-example.png)
\ No newline at end of file
diff --git a/docs/sections/pipeline_samples/index.md b/docs/sections/pipeline_samples/index.md
new file mode 100644
index 0000000000..6cf718faab
--- /dev/null
+++ b/docs/sections/pipeline_samples/index.md
@@ -0,0 +1,146 @@
+---
+hide: toc
+---
+# Tutorials
+
+- **End-to-end tutorials** provide detailed step-by-step explanations and the code used for end-to-end workflows.
+- **Paper implementations** provide reproductions of fundamental papers in the synthetic data domain.
+- **Examples** don't provide explenations but simply show code for different tasks.
+
+## End-to-end tutorials
+
+
+
+- __Generate a preference dataset__
+
+ ---
+
+ Learn about synthetic data generation for ORPO and DPO.
+
+ [:octicons-arrow-right-24: Tutorial](tutorials/generate_preference_dataset.ipynb)
+
+
+- __Clean an existing preference dataset__
+
+ ---
+
+ Learn about how to provide AI feedback to clean an existing dataset.
+
+ [:octicons-arrow-right-24: Tutorial](tutorials/clean_existing_dataset.ipynb)
+
+
+- __Retrieval and reranking models__
+
+ ---
+
+ Learn about synthetic data generation for fine-tuning custom retrieval and reranking models.
+
+ [:octicons-arrow-right-24: Tutorial](tutorials/GenerateSentencePair.ipynb)
+
+
+
+## Paper Implementations
+
+
+
+- __Deepseek Prover__
+
+ ---
+
+ Learn about an approach to generate mathematical proofs for theorems generated from informal math problems.
+
+ [:octicons-arrow-right-24: Example](papers/deepseek_prover.md)
+
+- __DEITA__
+
+ ---
+
+ Learn about prompt, response tuning for complexity and quality and LLMs as judges for automatic data selection.
+
+ [:octicons-arrow-right-24: Paper](papers/deita.md)
+
+- __Instruction Backtranslation__
+
+ ---
+
+ Learn about automatically labeling human-written text with corresponding instructions.
+
+ [:octicons-arrow-right-24: Paper](papers/instruction_backtranslation.md)
+
+- __Prometheus 2__
+
+ ---
+
+ Learn about using open-source models as judges for direct assessment and pair-wise ranking.
+
+ [:octicons-arrow-right-24: Paper](papers/prometheus.md)
+
+- __UltraFeedback__
+
+ ---
+
+ Learn about a large-scale, fine-grained, diverse preference dataset, used for training powerful reward and critic models.
+
+ [:octicons-arrow-right-24: Paper](papers/ultrafeedback.md)
+
+- __APIGen__
+
+ ---
+
+ Learn how to create verifiable high-quality datases for function-calling applications.
+
+ [:octicons-arrow-right-24: Paper](papers/apigen.md)
+
+- __CLAIR__
+
+ ---
+
+ Learn Contrastive Learning from AI Revisions (CLAIR), a data-creation method which leads to more contrastive preference pairs.
+
+ [:octicons-arrow-right-24: Paper](papers/clair.md)
+
+
+
+## Examples
+
+
+
+- __Benchmarking with distilabel__
+
+ ---
+
+ Learn about reproducing the Arena Hard benchmark with disitlabel.
+
+ [:octicons-arrow-right-24: Example](examples/benchmarking_with_distilabel.md)
+
+- __Structured generation with outlines__
+
+ ---
+
+ Learn about generating RPG characters following a pydantic.BaseModel with outlines in distilabel.
+
+ [:octicons-arrow-right-24: Example](examples/llama_cpp_with_outlines.md)
+
+- __Structured generation with instructor__
+
+ ---
+
+ Learn about answering instructions with knowledge graphs defined as pydantic.BaseModel objects using instructor in distilabel.
+
+ [:octicons-arrow-right-24: Example](examples/mistralai_with_instructor.md)
+
+- __Create a social network with FinePersonas__
+
+ ---
+
+ Learn how to leverage FinePersonas to create a synthetic social network and fine-tune adapters for Multi-LoRA.
+
+ [:octicons-arrow-right-24: Example](examples/fine_personas_social_network.md)
+
+
+
+
+
+
+
+
diff --git a/docs/sections/pipeline_samples/papers/apigen.md b/docs/sections/pipeline_samples/papers/apigen.md
new file mode 100644
index 0000000000..5d3522c1b7
--- /dev/null
+++ b/docs/sections/pipeline_samples/papers/apigen.md
@@ -0,0 +1,239 @@
+---
+hide: toc
+---
+
+# Create Function-Calling datasets with APIGen
+
+This example will introduce [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518), a data generation pipeline designed to synthesize verifiable high-quality datasets for function-calling applications.
+
+## Replication
+
+The following figure showcases the APIGen framework:
+
+![APIGen framework](../../../assets/tutorials-assets/overview-apigen.jpg)
+
+Now, let's walk through the key steps illustrated in the figure:
+
+- [`DataSampler`](https://distilabel.argilla.io/dev/components-gallery/step/datasampler/): With the help of this step and the original [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k) we are getting the Seed QA Data Sampler for the prompt template.
+
+- [`APIGenGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/apigengenerator/): This step does the job of the *Query-Answer Generator*, including the format checker from *Stage 1: Format Checker* thanks to the structured output generation.
+
+- [`APIGenExecutionChecker`](https://distilabel.argilla.io/dev/components-gallery/task/apigenexecutionchecker/): This step is in charge of the *Stage 2: Execution Checker*.
+
+- [`APIGenSemanticChecker`](https://distilabel.argilla.io/dev/components-gallery/task/apigensemanticchecker/): Step in charge of running *Stage 3: Semantic Checker*, can use the same or a different LLM, we are using the same as in [`APIGenGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/apigengenerator/) step.
+
+The current implementation hasn't utilized the *Diverse Prompt Library*. To incorporate it, one could either adjust the prompt template within the [`APIGenGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/apigengenerator/) or develop a new sampler specifically for this purpose. As for the *API Sampler*, while no specific data is shared here, we've created illustrative examples to demonstrate the pipeline's functionality. These examples represent a mix of data that could be used to replicate the sampler's output.
+
+## Data preparation
+
+The original paper tells about the data they used and give some hints, but nothing was shared. In this example, we will write a bunch of examples by hand to showcase how this pipeline can be built.
+
+Assume we have the following function names, and corresponding descriptions of their behaviour:
+
+```python
+data = [
+ {
+ "func_name": "final_velocity",
+ "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+ },
+ {
+ "func_name": "permutation_count",
+ "func_desc": "Calculates the number of permutations of k elements from a set of n elements.",
+ },
+ {
+ "func_name": "getdivision",
+ "func_desc": "Divides two numbers by making an API call to a division service.",
+ },
+ {
+ "func_name": "binary_addition",
+ "func_desc": "Adds two binary numbers and returns the result as a binary string.",
+ },
+ {
+ "func_name": "swapi_planet_resource",
+ "func_desc": "get a specific planets resource",
+ },
+ {
+ "func_name": "disney_character",
+ "func_desc": "Find a specific character using this endpoint",
+ }
+]
+```
+
+The original paper refers to both python functions and APIs, but we will make use of python functions exclusively for simplicity. In order to execute and check this functions/APIs, we need access to the code, which we have moved to a Python file: [lib_apigen.py](https://github.com/argilla-io/distilabel/blob/main/examples/lib_apigen.py). All this functions are executable, but we also need access to their *tool* representation. For this, we will make use of transformers' *get_json_schema* function[^1].
+
+[^1]: Read this nice blog post for more information on tools and the reasoning behind `get_json_schema`: [Tool Use, Unified](https://huggingface.co/blog/unified-tool-use).
+
+We have all the machinery prepared in our libpath, except from the *tool* definition. With the help of our helper function `load_module_from_path` we will load this python module, collect all the tools, and add them to each row in our `data` variable.
+
+```python
+from distilabel.steps.tasks.apigen.utils import load_module_from_path
+
+libpath_module = load_module_from_path(libpath)
+tools = getattr(libpath_module, "get_tools")() # call get_tools()
+
+for row in data:
+ # The tools should have a mix where both the correct and irrelevant tools are present.
+ row.update({"tools": [tools[row["func_name"]]]})
+```
+
+Now we have all the necessary data for our prompt. Additionally, we will make use of the original dataset as few-shot examples to enhance the model:
+
+```python
+ds_og = (
+ load_dataset("Salesforce/xlam-function-calling-60k", split="train")
+ .shuffle(seed=42)
+ .select(range(500))
+ .to_list()
+)
+```
+
+We have just loaded a subset and transformed it to a list of dictionaries, as we will use it in the [`DataSampler`](https://distilabel.argilla.io/dev/components-gallery/steps/datasampler/) `GeneratorStep`, grabbing random examples from the original dataset.
+
+## Building the Pipeline
+
+Now that we've walked through each component, it's time to see how it all comes together, here's the Pipeline code:
+
+```python
+with Pipeline(name="apigen-example") as pipeline:
+ loader_seeds = LoadDataFromDicts(data=data) # (1)
+
+ sampler = DataSampler( # (2)
+ data=ds_og,
+ size=2,
+ samples=len(data),
+ batch_size=8,
+ )
+
+ prep_examples = PrepareExamples() # This step will add the 'examples' column
+
+ combine_steps = CombineOutputs() # (3)
+
+ model_id = "meta-llama/Meta-Llama-3.1-70B-Instruct"
+ llm=InferenceEndpointsLLM( # (4)
+ model_id=model_id,
+ tokenizer_id=model_id,
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 2048,
+ },
+ )
+ apigen = APIGenGenerator( # (5)
+ llm=llm,
+ use_default_structured_output=True,
+ )
+
+ execution_checker = APIGenExecutionChecker(libpath=str(libpath)) # (6)
+ semantic_checker = APIGenSemanticChecker(llm=llm) # (7)
+
+ sampler >> prep_examples
+ (
+ [loader_seeds, prep_examples]
+ >> combine_steps
+ >> apigen
+ >> execution_checker
+ >> semantic_checker
+ )
+```
+
+1. Load the data seeds we are going to use to generate our function calling dataset.
+
+2. The `DataSampler` together with `PrepareExamples` will be used to help us create the few-shot
+examples from the original dataset to be fed in our prompt.
+
+3. Combine both columns to obtain a single stream of data
+
+4. Will reuse the same LLM for the generation and the semantic checks.
+
+5. Creates the `query` and `answers` that will be used together with the `tools` to fine-tune a new model. Will generate the structured outputs to ensure we have valid JSON formatted answers.
+
+6. Adds columns `keep_row_after_execution_check` and `execution_result`.
+
+7. Adds columns `keep_row_after_semantic_check` and `thought`.
+
+## Script and final dataset
+
+To see all the pieces in place, take a look at the full pipeline, as well as an example row that would be generated from this pipeline.
+
+??? Run
+
+ ```python
+ python examples/pipeline_apigen.py
+ ```
+
+```python title="pipeline_apigen.py"
+--8<-- "examples/pipeline_apigen.py"
+```
+
+Example row:
+
+```json
+{
+ "func_name": "final_velocity",
+ "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+ "tools": [
+ {
+ "function": {
+ "description": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+ "name": "final_velocity",
+ "parameters": {
+ "properties": {
+ "acceleration": {
+ "description": "The acceleration of the object.",
+ "type": "number"
+ },
+ "initial_velocity": {
+ "description": "The initial velocity of the object.",
+ "type": "number"
+ },
+ "time": {
+ "description": "The time elapsed.",
+ "type": "number"
+ }
+ },
+ "required": [
+ "initial_velocity",
+ "acceleration",
+ "time"
+ ],
+ "type": "object"
+ }
+ },
+ "type": "function"
+ }
+ ],
+ "examples": "## Query:\nRetrieve the first 15 comments for post ID '12345' from the Tokapi mobile API.\n## Answers:\n[{\"name\": \"v1_post_post_id_comments\", \"arguments\": {\"post_id\": \"12345\", \"count\": 15}}]\n\n## Query:\nRetrieve the detailed recipe for the cake with ID 'cake101'.\n## Answers:\n[{\"name\": \"detailed_cake_recipe_by_id\", \"arguments\": {\"is_id\": \"cake101\"}}]\n\n## Query:\nWhat are the frequently asked questions and their answers for Coca-Cola Company? Also, what are the suggested tickers based on Coca-Cola Company?\n## Answers:\n[{\"name\": \"symbols_faq\", \"arguments\": {\"ticker_slug\": \"KO\"}}, {\"name\": \"symbols_suggested\", \"arguments\": {\"ticker_slug\": \"KO\"}}]",
+ "query": "What would be the final velocity of an object that starts at rest and accelerates at 9.8 m/s^2 for 10 seconds.",
+ "answers": "[{\"arguments\": {\"acceleration\": \"9.8\", \"initial_velocity\": \"0\", \"time\": \"10\"}, \"name\": \"final_velocity\"}]",
+ "distilabel_metadata": {
+ "raw_input_a_p_i_gen_generator_0": [
+ {
+ "content": "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively",
+ "role": "system"
+ },
+ {
+ "content": "Here are examples of queries and the corresponding answers for similar functions:\n## Query:\nRetrieve the first 15 comments for post ID '12345' from the Tokapi mobile API.\n## Answers:\n[{\"name\": \"v1_post_post_id_comments\", \"arguments\": {\"post_id\": \"12345\", \"count\": 15}}]\n\n## Query:\nRetrieve the detailed recipe for the cake with ID 'cake101'.\n## Answers:\n[{\"name\": \"detailed_cake_recipe_by_id\", \"arguments\": {\"is_id\": \"cake101\"}}]\n\n## Query:\nWhat are the frequently asked questions and their answers for Coca-Cola Company? Also, what are the suggested tickers based on Coca-Cola Company?\n## Answers:\n[{\"name\": \"symbols_faq\", \"arguments\": {\"ticker_slug\": \"KO\"}}, {\"name\": \"symbols_suggested\", \"arguments\": {\"ticker_slug\": \"KO\"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\n\nBased on these examples, generate 1 diverse query and answer pairs for the function `final_velocity`.\nThe detailed function description is the following:\nCalculates the final velocity of an object given its initial velocity, acceleration, and time.\n\nThese are the available tools to help you:\n[{'type': 'function', 'function': {'name': 'final_velocity', 'description': 'Calculates the final velocity of an object given its initial velocity, acceleration, and time.', 'parameters': {'type': 'object', 'properties': {'initial_velocity': {'type': 'number', 'description': 'The initial velocity of the object.'}, 'acceleration': {'type': 'number', 'description': 'The acceleration of the object.'}, 'time': {'type': 'number', 'description': 'The time elapsed.'}}, 'required': ['initial_velocity', 'acceleration', 'time']}}}]\n\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n```json\n[\n {\n \"query\": \"The generated query.\",\n \"answers\": [\n {\n \"name\": \"api_name\",\n \"arguments\": {\n \"arg_name\": \"value\"\n ... (more arguments as required)\n }\n },\n ... (more API calls as required)\n ]\n }\n]\n```\n\nNow please generate 1 diverse query and answer pairs following the above format.",
+ "role": "user"
+ }
+ ],
+ "raw_input_a_p_i_gen_semantic_checker_0": [
+ {
+ "content": "As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user\u2019s intentions.\n\nDo not pass if:\n1. The function call does not align with the query\u2019s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user\u2019s intentions.\n4. The execution results are irrelevant and do not match the function\u2019s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.",
+ "role": "system"
+ },
+ {
+ "content": "Given Information:\n- All Available Functions:\nCalculates the final velocity of an object given its initial velocity, acceleration, and time.\n- User Query: What would be the final velocity of an object that starts at rest and accelerates at 9.8 m/s^2 for 10 seconds.\n- Generated Function Calls: [{\"arguments\": {\"acceleration\": \"9.8\", \"initial_velocity\": \"0\", \"time\": \"10\"}, \"name\": \"final_velocity\"}]\n- Execution Results: ['9.8']\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query's intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n```\n{\n \"thought\": \"Concisely describe your reasoning here\",\n \"passes\": \"yes\" or \"no\"\n}\n```\n",
+ "role": "user"
+ }
+ ],
+ "raw_output_a_p_i_gen_generator_0": "{\"pairs\": [\n {\n \"answers\": [\n {\n \"arguments\": {\n \"acceleration\": \"9.8\",\n \"initial_velocity\": \"0\",\n \"time\": \"10\"\n },\n \"name\": \"final_velocity\"\n }\n ],\n \"query\": \"What would be the final velocity of an object that starts at rest and accelerates at 9.8 m/s^2 for 10 seconds.\"\n }\n]}",
+ "raw_output_a_p_i_gen_semantic_checker_0": "{\n \"thought\": \"\",\n \"passes\": \"yes\"\n}"
+ },
+ "model_name": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+ "keep_row_after_execution_check": true,
+ "execution_result": [
+ "9.8"
+ ],
+ "thought": "",
+ "keep_row_after_semantic_check": true
+}
+```
diff --git a/docs/sections/pipeline_samples/papers/clair.md b/docs/sections/pipeline_samples/papers/clair.md
new file mode 100644
index 0000000000..8c0887460b
--- /dev/null
+++ b/docs/sections/pipeline_samples/papers/clair.md
@@ -0,0 +1,84 @@
+# Contrastive Learning From AI Revisions (CLAIR)
+
+["Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment"](https://huggingface.co/papers/2408.06266) introduces both Contrastive
+Learning from AI Revisions (CLAIR), a data-creation method which leads to more contrastive preference pairs, and Anchored Preference Optimization (APO), a controllable and more stable alignment objective. While APO can be found in [TRL](https://huggingface.co/docs/trl/dpo_trainer#loss-functions), we have implemented a task for CLAIR in `distilabel`.
+
+CLAIR is a method for creating preference pairs which minimally revises one output to express a preference, resulting in a more precise learning signal as opposed to conventional methods which use a judge to select a preferred response.
+
+![CLAIR overview](../../../assets/pipelines/clair.png)
+
+The athors from the original paper shared a [collection of datasets from CLAIR and APO](https://huggingface.co/collections/ContextualAI/clair-and-apo-66b52868672bb1c984d1f3d5), where [ContextualAI/ultrafeedback_clair_32k](https://huggingface.co/datasets/ContextualAI/ultrafeedback_clair_32k) corresponds to the CLAIR implementation.
+
+### Replication
+
+!!! NOTE
+ The section is named `Replication` but in this case we are showing how to use the [`CLAIR`][distilabel.steps.tasks.clair.CLAIR] task create revisions for your generations using `distilabel`.
+
+To showcase CLAIR we will be using the [`CLAIR`][distilabel.steps.tasks.PrometheusEval] task implemented in `distilabel` and we are reusing a small sample of the already generated dataset by ContextualAI [`ContextualAI/ultrafeedback_clair_32k`](https://huggingface.co/datasets/ContextualAI/ultrafeedback_clair_32k) for testing.
+
+#### Installation
+
+To reproduce the code below, one will need to install `distilabel` as follows:
+
+```bash
+pip install "distilabel>=1.4.0"
+```
+
+Depending on the LLM provider you want to use, the requirements may vary, take a look at the dependencies in that case, we are using for the example the free inference endpoints from Hugging Face, but that won't apply for a bigger dataset.
+
+#### Building blocks
+
+In this case where we already have instructions and their generations, we will just need to load the data and the corresponding CLAIR task for the revisions:
+
+- [`CLAIR`](https://distilabel.argilla.io/dev/components-gallery/tasks/clair/) to generate the revisions.
+
+#### Code
+
+Let's see the full pipeline applied to `ContextualAI/ultrafeedback_clair_32k` in `distilabel`:
+
+```python
+from typing import Any, Dict
+
+from datasets import load_dataset
+
+from distilabel.pipeline import Pipeline
+from distilabel.steps.tasks import CLAIR
+from distilabel.llms import InferenceEndpointsLLM
+
+
+def transform_ultrafeedback(example: Dict[str, Any]) -> Dict[str, Any]:
+ return {
+ "task": example["prompt"],
+ "student_solution": example["rejected"][1]["content"],
+ }
+
+dataset = (
+ load_dataset("ContextualAI/ultrafeedback_clair_32k", split="train")
+ .select(range(10)) # We collect just 10 examples
+ .map(transform_ultrafeedback) # Apply the transformation to get just the text
+)
+
+with Pipeline(name="CLAIR UltraFeedback sample") as pipeline:
+ clair = CLAIR( # (1)
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 4096
+ }
+ )
+ )
+
+
+if __name__ == "__main__":
+ distiset = pipeline.run(dataset=dataset) # (2)
+ distiset.push_to_hub(repo_id="username/clair-test", include_script=True) # (3)
+```
+
+1. This Pipeline uses just CLAIR because we already have the generations, but one can just include a first task to create generations from instructions, and then the revisions with CLAIR.
+
+2. Include the dataset directly in the run method for simplicity.
+
+3. Push the distiset to the hub with the script for reproducibility.
+
+An example dataset can be found at: [distilabel-internal-testing/clair-test](https://huggingface.co/datasets/distilabel-internal-testing/clair-test).
diff --git a/docs/sections/pipeline_samples/papers/deepseek_prover.md b/docs/sections/pipeline_samples/papers/deepseek_prover.md
index 29e86c64cf..c7ecfcc32d 100644
--- a/docs/sections/pipeline_samples/papers/deepseek_prover.md
+++ b/docs/sections/pipeline_samples/papers/deepseek_prover.md
@@ -8,6 +8,8 @@ The authors propose a method for generating [Lean 4](https://github.com/leanprov
Here we show how to deal with steps 1 and 2, but the authors ensure the theorems are checked using the [lean4](https://github.com/leanprover/lean4) program on the generated proofs, and iterate for a series of steps, fine-tuning a model on the synthetic data (DeepSeek prover 7B), regenerating the dataset, and continue the process until no further improvement is found.
+![DEITA pipeline overview](../../../assets/pipelines/deepseek.png)
+
### Replication
!!! Note
@@ -32,7 +34,7 @@ There are three components we needed to define for this pipeline, for the differ
!!! Note
We will use the same `LLM` for all the tasks, so we will define once and reuse it for the different tasks:
-
+
```python
llm = InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-70B-Instruct",
diff --git a/docs/sections/pipeline_samples/papers/deita.md b/docs/sections/pipeline_samples/papers/deita.md
index 5c9036d756..b9d3e9eea6 100644
--- a/docs/sections/pipeline_samples/papers/deita.md
+++ b/docs/sections/pipeline_samples/papers/deita.md
@@ -1,15 +1,17 @@
# DEITA
-DEITA (Data-Efficient Instruction Tuning for Alignment) studies an automatic data selection process by first quantifying the data quality based on complexity, quality and diversity. And second, selecting across the best potential combination from an open-source dataset that would fit into the budget you allocate to tune your own LLM.
+[DEITA (Data-Efficient Instruction Tuning for Alignment)](https://arxiv.org/abs/2312.15685) studies an automatic data selection process by first quantifying the data quality based on complexity, quality and diversity. Second, select the best potential combination from an open-source dataset that would fit into the budget you allocate to tune your own LLM.
-In most setting we cannot allocate unlimited resources for instruction-tuning LLMs. Therefore, the DEITA authors investigated how to select qualitative data for instruction-tuning based on a principle of fewer high quality samples. Liu et al. tackle the issue of first defining good data and second identifying it to respect an initial budget to instruct-tune your LLM.
+In most setting we cannot allocate unlimited resources for instruction-tuning LLMs. Therefore, the DEITA authors investigated how to select qualitative data for instruction tuning based on the principle of fewer high-quality samples. Liu et al. tackle the issue of first defining good data and second identifying it to respect an initial budget to instruct-tune your LLM.
-The strategy utilizes **LLMs to replace human effort in time-intensive data quality tasks on instruction tuning datasets**. DEITA introduces a way to measure data quality across three critical dimensions: complexity, quality and diversity.
+The strategy utilizes **LLMs to replace human effort in time-intensive data quality **tasks on **instruction-tuning** datasets**. DEITA introduces a way to measure data quality across three critical dimensions: complexity, quality and diversity.
![DEITA pipeline overview](../../../assets/tutorials-assets/deita/overview.png)
You can see that we see again the dataset of instructions/responses and we kind of reproducing the second step when we learn how to optimize the responses according to an instruction by comparing several possibilities.
+![DEITA pipeline overview](../../../assets/pipelines/deita.png)
+
### Datasets and budget
We will dive deeper into the whole process. We will investigate each stage to efficiently select the final dataset used for supervised fine-tuning with a budget constraint. We will tackle technical challenges by explaining exactly how you would assess good data as presented in the paper.
diff --git a/docs/sections/pipeline_samples/papers/index.md b/docs/sections/pipeline_samples/papers/index.md
deleted file mode 100644
index 7fed3da03a..0000000000
--- a/docs/sections/pipeline_samples/papers/index.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Paper Implementations
-
-Contains some implementations for synthetic data generation papers, using `distilabel`, providing reproducible pipelines so that anyone can play around with those approaches and customize that to their needs. We strongly believe that better data leads to better models, and synthetic data has proven to be really effective towards improving LLMs, so we aim to bridge the gap between research and practice by providing these implementations.
diff --git a/docs/sections/pipeline_samples/papers/instruction_backtranslation.md b/docs/sections/pipeline_samples/papers/instruction_backtranslation.md
index 8434742984..b3a6b20d68 100644
--- a/docs/sections/pipeline_samples/papers/instruction_backtranslation.md
+++ b/docs/sections/pipeline_samples/papers/instruction_backtranslation.md
@@ -1,18 +1,20 @@
# Instruction Backtranslation
-["Self Alignment with Instruction Backtranslation"](https://arxiv.org/abs/2308.06259) presents a scalable method to build a high quality instruction following language model by automatically labelling human-written text with corresponding instructions. Their approach, named instruction backtranslation, starts with a language model finetuned on a small amount of seed data, and a given web corpus. The seed model is used to construct training examples by generating instruction prompts for web documents (self-augmentation), and then selecting high quality examples from among these candidates (self-curation). This data is then used to finetune a stronger model.
+["Self Alignment with Instruction Backtranslation"](https://arxiv.org/abs/2308.06259) presents a scalable method to build high-quality instruction following a language model by automatically labeling human-written text with corresponding instructions. Their approach, named instruction backtranslation, starts with a language model finetuned on a small amount of seed data, and a given web corpus. The seed model is used to construct training examples by generating instruction prompts for web documents (self-augmentation), and then selecting high-quality examples from among these candidates (self-curation). This data is then used to finetune a stronger model.
-Their self-training approach assumes access to a base language model, a small amount of seed data, and a collection of unlabelled examples, e.g. a web corpus. The unlabelled data is a large, diverse set of human-written documents which includes writing about all manner of topics humans are interested in – but crucially is not paired with instructions.
+![Instruction Backtranslation pipeline overview](../../../assets/pipelines/instruction_backtranslation.png)
-A first key assumption is that there exists some subset of this very large human-written text that would be suitable as gold generations for some user instructions. A second key assumption is that they can predict instructions for these candidate gold answers that can be used as high quality example pairs to train an instruction following model.
+Their self-training approach assumes access to a base language model, a small amount of seed data, and a collection of unlabelled examples, e.g. a web corpus. The unlabelled data is a large, diverse set of human-written documents that includes writing about all manner of topics humans are interested in – but crucially is not paired with instructions.
-Their overall process, called instruction backtranslation performs two core steps:
+A first key assumption is that there exists some subset of this very large human-written text that would be suitable as gold generations for some user instructions. A second key assumption is that they can predict instructions for these candidate gold answers that can be used as high-quality example pairs to train an instruction-following model.
+
+Their overall process, called instruction back translation performs two core steps:
1. Self-augment: Generate instructions for unlabelled data, i.e. the web corpus, to produce candidate training data of (instruction, output) pairs for instruction tuning.
-2. Self-curate: Self-select high quality demonstration examples as training data to finetune the base model to follow instructions. This approach is done iteratively where a better intermediate instruction-following model can improve on selecting data for finetuning in the next iteration.
+2. Self-curate: Self-select high-quality demonstration examples as training data to finetune the base model to follow instructions. This approach is done iteratively where a better intermediate instruction-following model can improve on selecting data for finetuning in the next iteration.
-This replication covers the self-curation step i.e. the second / latter step as mentioned above, so as to be able to use the proposed prompting approach to rate the quality of the generated text, which can either be synthetically generated or real human-written text.
+This replication covers the self-curation step i.e. the second/latter step as mentioned above, so as to be able to use the proposed prompting approach to rate the quality of the generated text, which can either be synthetically generated or real human-written text.
### Replication
diff --git a/docs/sections/pipeline_samples/papers/prometheus.md b/docs/sections/pipeline_samples/papers/prometheus.md
index 7f7b1d19d5..c8a3fb16c5 100644
--- a/docs/sections/pipeline_samples/papers/prometheus.md
+++ b/docs/sections/pipeline_samples/papers/prometheus.md
@@ -1,20 +1,22 @@
# Prometheus 2
-["Prometheus 2: An Open Source Language Model Specialized in Evaluating Other Language Models"](https://arxiv.org/pdf/2405.01535) presents Prometheus 2, a new and more powerful evaluator LLM compared to Prometheus (its predecessor) presented in ["Prometheus: Inducing Fine-grained Evaluation Capability in Language Models"](https://arxiv.org/abs/2310.08491); since GPT-4, as well as other proprietary LLMs, are commonly used to asses the quality of the responses for various LLMs, but there are concerns about transparency, controllability, and affordability, that motivate the need of open-source LLMs specialized in evaluations.
+["Prometheus 2: An Open Source Language Model Specialized in Evaluating Other Language Models"](https://arxiv.org/pdf/2405.01535) presents Prometheus 2, a new and more powerful evaluator LLM compared to Prometheus (its predecessor) presented in ["Prometheus: Inducing Fine-grained Evaluation Capability in Language Models"](https://arxiv.org/abs/2310.08491); since GPT-4, as well as other proprietary LLMs, are commonly used to assess the quality of the responses for various LLMs, but there are concerns about transparency, controllability, and affordability, that motivate the need of open-source LLMs specialized in evaluations.
+
+![Prometheus 2 pipeline overview](../../../assets/pipelines/prometheus.png)
Existing open evaluator LMs exhibit critical shortcomings:
1. They issue scores that significantly diverge from those assigned by humans.
2. They lack the flexibility to perform both direct assessment and pairwise ranking, the two most prevalent forms of assessment.
-Additionally, they do not possess the ability to evaluate based on custom evaluation criteria, focusing instead on general attributes like helpfulness and harmlessness. Prometheus 2 is capable of processing both direct assessment and pair-wise ranking formats grouped with a user-defined evaluation criteria.
+Additionally, they do not possess the ability to evaluate based on custom evaluation criteria, focusing instead on general attributes like helpfulness and harmlessness. Prometheus 2 is capable of processing both direct assessment and pair-wise ranking formats grouped with user-defined evaluation criteria.
Prometheus 2 released two variants:
- [`prometheus-eval/prometheus-7b-v2.0`](https://hf.co/prometheus-eval/prometheus-7b-v2.0): fine-tuned on top of [`mistralai/Mistral-7B-Instruct-v0.2`](https://hf.co/mistralai/Mistral-7B-Instruct-v0.2)
- [`prometheus-eval/prometheus-8x7b-v2.0`](https://hf.co/prometheus-eval/prometheus-8x7b-v2.0): fine-tuned on top of [`mistralai/Mixtral-8x7B-Instruct-v0.1`](https://hf.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
-Both models have been fine-tuned for both direct assessment and pairwise ranking tasks i.e. assessing the quality of a single isolated response for a given instruction with or without a reference answer, and assessing the quality of one response against another one for a given instruction with or without a reference answer, respectively.
+Both models have been fine-tuned for both direct assessment and pairwise ranking tasks i.e. assessing the quality of a single isolated response for a given instruction with or without a reference answer and assessing the quality of one response against another one for a given instruction with or without a reference answer, respectively.
On four direct assessment benchmarks and four pairwise ranking benchmarks, Prometheus 2 scores the highest correlation and agreement with humans and proprietary LM judges among all tested open evaluator LMs. Their models, code, and data are all publicly available at [`prometheus-eval/prometheus-eval`](https://github.com/prometheus-eval/prometheus-eval).
diff --git a/docs/sections/pipeline_samples/papers/ultrafeedback.md b/docs/sections/pipeline_samples/papers/ultrafeedback.md
index 704309e263..83acc9f335 100644
--- a/docs/sections/pipeline_samples/papers/ultrafeedback.md
+++ b/docs/sections/pipeline_samples/papers/ultrafeedback.md
@@ -4,15 +4,17 @@
UltraFeedback collects about 64k prompts from diverse resources (including UltraChat, ShareGPT, Evol-Instruct, TruthfulQA, FalseQA, and FLAN), then they use these prompts to query multiple LLMs (commercial models, Llama models ranging 7B to 70B, and non-Llama models) and generate four different responses for each prompt, resulting in a total of 256k samples i.e. the UltraFeedback will rate four responses on every OpenAI request.
+![UltraFeedback pipeline overview](../../../assets/pipelines/ultrafeedback.png)
+
To collect high-quality preference and textual feedback, they design a fine-grained annotation instruction, which contains four different aspects, namely instruction-following, truthfulness, honesty and helpfulness (even though within the paper they also mention a fifth one named verbalized calibration). Finally, GPT-4 is used to generate the ratings for the generated responses to the given prompt using the previously mentioned aspects.
-### Replication
+## Replication
To replicate the paper we will be using `distilabel` and a smaller dataset created by the Hugging Face H4 team named [`HuggingFaceH4/instruction-dataset`](https://huggingface.co/datasets/HuggingFaceH4/instruction-dataset) for testing purposes.
Also for testing purposes we will just show how to evaluate the generated responses for a given prompt using a new global aspect named `overall-rating` defined by Argilla, that computes the average of the four aspects, so as to reduce number of requests to be sent to OpenAI, but note that all the aspects are implemented within `distilabel` and can be used instead for a more faithful reproduction. Besides that we will generate three responses for each instruction using three LLMs selected from a pool of six: [`HuggingFaceH4/zephyr-7b-beta`](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta), [`argilla/notus-7b-v1`](https://huggingface.co/argilla/notus-7b-v1), [`google/gemma-1.1-7b-it`](https://huggingface.co/google/gemma-1.1-7b-it), [`meta-llama/Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct), [`HuggingFaceH4/zephyr-7b-gemma-v0.1`](https://huggingface.co/HuggingFaceH4/zephyr-7b-gemma-v0.1) and [`mlabonne/UltraMerge-7B`](https://huggingface.co/mlabonne/UltraMerge-7B).
-#### Installation
+### Installation
To replicate UltraFeedback one will need to install `distilabel` as it follows:
@@ -22,7 +24,7 @@ pip install "distilabel[argilla,openai,vllm]>=1.0.0"
And since we will be using `vllm` we will need to use a VM with at least 6 NVIDIA GPUs with at least 16GB of memory each to run the text generation, and set the `OPENAI_API_KEY` environment variable value.
-#### Building blocks
+### Building blocks
- [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub]: Generator Step to load a dataset from the Hugging Face Hub.
- [`sample_n_steps`][distilabel.pipeline.sample_n_steps]: Function to create a `routing_batch_function` that samples `n` downstream steps for each batch generated by the upstream step. This is the key to replicate the LLM pooling mechanism described in the paper.
@@ -34,7 +36,7 @@ And since we will be using `vllm` we will need to use a VM with at least 6 NVIDI
- [`KeepColumns`][distilabel.steps.KeepColumns]: Task to keep the desired columns while removing the not needed ones, as well as defining the order for those.
- (optional) [`PreferenceToArgilla`][distilabel.steps.PreferenceToArgilla]: Task to optionally push the generated dataset to Argilla to do some further analysis and human annotation.
-#### Code
+### Code
As mentioned before, we will put the previously mentioned building blocks together to replicate UltraFeedback.
diff --git a/docs/sections/pipeline_samples/tutorials/GenerateSentencePair.ipynb b/docs/sections/pipeline_samples/tutorials/GenerateSentencePair.ipynb
new file mode 100644
index 0000000000..0779a53eb9
--- /dev/null
+++ b/docs/sections/pipeline_samples/tutorials/GenerateSentencePair.ipynb
@@ -0,0 +1,633 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Synthetic data generation for fine-tuning custom retrieval and reranking models\n",
+ "\n",
+ "- **Goal**: Bootstrap, optimize and maintain your embedding models and rerankers through synthetic data generation and human feedback.\n",
+ "- **Libraries**: [argilla](https://github.com/argilla-io/argilla), [hf-inference-endpoints](https://github.com/huggingface/huggingface_hub), [sentence-transformers](https://github.com/UKPLab/sentence-transformers)\n",
+ "- **Components**: [LoadDataFromHub](https://distilabel.argilla.io/latest/components-gallery/steps/loaddatafromhub/), [GenerateSentencePair](https://distilabel.argilla.io/latest/components-gallery/tasks/generatesentencepair/), [InferenceEndpointsLLM](https://distilabel.argilla.io/latest/components-gallery/llms/inferenceendpointsllm/)\n",
+ "\n",
+ "![GenerateSentencePair pipeline overview](../../../assets/pipelines/sentence-transformer.png)\n",
+ "\n",
+ "!!! note\n",
+ " For a comprehensive overview on optimizing the retrieval performance in a RAG pipeline, check this [guide](https://docs.zenml.io/user-guide/llmops-guide/finetuning-embeddings) in collaboration with [ZenML](https://github.com/zenml-io/zenml), an open-source MLOps framework designed for building portable and production-ready machine learning pipelines."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Getting started\n",
+ "\n",
+ "### Install the dependencies\n",
+ "\n",
+ "To complete this tutorial, you need to install the distilabel SDK and a few third-party libraries via pip. We will be using **the free but rate-limited Hugging Face serverless Inference API** for this tutorial, so we need to install this as an extra distilabel dependency. You can install them by running the following command:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"distilabel[hf-inference-endpoints]\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"sentence-transformers~=3.0\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's make the needed imports:\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from distilabel.llms.huggingface import InferenceEndpointsLLM\n",
+ "from distilabel.pipeline import Pipeline\n",
+ "from distilabel.steps.tasks import GenerateSentencePair\n",
+ "from distilabel.steps import LoadDataFromHub\n",
+ "\n",
+ "from sentence_transformers import SentenceTransformer, CrossEncoder\n",
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You'll need an `HF_TOKEN` to use the HF Inference Endpoints. Login to use it directly within this notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from huggingface_hub import login\n",
+ "\n",
+ "login(token=os.getenv(\"HF_TOKEN\"), add_to_git_credential=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### (optional) Deploy Argilla\n",
+ "\n",
+ "You can skip this step or replace it with any other data evaluation tool, but the quality of your model will suffer from a lack of data quality, so we do recommend looking at your data. If you already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/). \n",
+ "\n",
+ "Along with that, you will need to install Argilla as a distilabel extra."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"distilabel[argilla, hf-inference-endpoints]\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's make the extra needed imports:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import argilla as rg"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## The dataset\n",
+ "\n",
+ "Before starting any project, it is always important to look at your data. Our data is publicly available [on the Hugging Face Hub](https://huggingface.co/datasets/plaguss/argilla_sdk_docs_raw_unstructured?row=0) so we can have a quick look through [their dataset viewer within an embedded iFrame](https://huggingface.co/docs/hub/datasets-viewer-embed). \n",
+ "\n",
+ "\n",
+ "\n",
+ "As we can see, our dataset contains a column called `chunks`, which was obtained from the Argilla docs. Normally, you would need to download and chunk the data but we will not cover that in this tutorial. To read a full explanation for how this dataset was generated, please refer to [How we leveraged distilabel to create an Argilla 2.0 Chatbot](https://huggingface.co/blog/argilla-chatbot#downloading-and-chunking-data).\n",
+ "\n",
+ "Alternatively, we can load the entire dataset to disk with `datasets.load_dataset`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Synthetic data generation\n",
+ "\n",
+ "The [`GenerateSentencePair`](https://distilabel.argilla.io/latest/components-gallery/tasks/generatesentencepair/) component from `distilabel` can be used to generate training datasets for embeddings models. \n",
+ "\n",
+ "It is a pre-defined `Task` that given an `anchor` sentence generate data for a specific `action`. Supported actions are: `\"paraphrase\", \"semantically-similar\", \"query\", \"answer\"`. In our case the `chunks` column corresponds to the `anchor`. This means we will use `query` to generate potential queries for a fine-tuning a retrieval model and that we will use `semantically-similar` to generate texts that are similar to the intial anchor for fine-tuning a reranking model.\n",
+ "\n",
+ "We will `triplet=True` in order to generate both positive and negative examples, which should help the model generalize better during fine-tuning and we will set `hard_negative=True` to generate more challenging examples that are closer to the anchor and discussed topics.\n",
+ "\n",
+ "Lastly, we can seed the LLM with `context` to generate more relevant examples."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "context = (\n",
+ "\"\"\"\n",
+ "The text is a chunk from technical Python SDK documentation of Argilla.\n",
+ "Argilla is a collaboration tool for AI engineers and domain experts to build high-quality datasets.\n",
+ "Along with prose explanations, the text chunk may include code snippets and Python references.\n",
+ "\"\"\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Retrieval\n",
+ "\n",
+ "For retrieval, we will thus generate queries that are similar to the `chunks` column. We will use the `query` action to generate potential queries for a fine-tuning a retrieval model.\n",
+ "\n",
+ "```python\n",
+ "generate_sentence_pair = GenerateSentencePair(\n",
+ " triplet=True, \n",
+ " hard_negative=True,\n",
+ " action=\"query\",\n",
+ " llm=llm,\n",
+ " input_batch_size=10,\n",
+ " context=context,\n",
+ ")\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Reranking\n",
+ "\n",
+ "For reranking, we will generate texts that are similar to the intial anchor. We will use the `semantically-similar` action to generate texts that are similar to the intial anchor for fine-tuning a reranking model. In this case, we set `hard_negative=False` to generate more diverse and potentially wrong examples, which can be used as negative examples for similarity fine-tuning because [rerankers cannot be fine-tuned using triplets](https://github.com/UKPLab/sentence-transformers/issues/2366).\n",
+ "\n",
+ "```python\n",
+ "generate_sentence_pair = GenerateSentencePair(\n",
+ " triplet=True,\n",
+ " hard_negative=False,\n",
+ " action=\"semantically-similar\",\n",
+ " llm=llm,\n",
+ " input_batch_size=10,\n",
+ " context=context,\n",
+ ")\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Combined pipeline\n",
+ "\n",
+ "We will now use the `GenerateSentencePair` task to generate synthetic data for both retrieval and reranking models in a single pipeline. Note that, we map the `chunks` column to the `anchor` argument."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "llm = InferenceEndpointsLLM(\n",
+ " model_id=\"mistralai/Mistral-7B-Instruct-v0.2\",\n",
+ " tokenizer_id=\"mistralai/Mistral-7B-Instruct-v0.2\",\n",
+ ")\n",
+ "\n",
+ "with Pipeline(name=\"generate\") as pipeline:\n",
+ " load_dataset = LoadDataFromHub(\n",
+ " num_examples=15,\n",
+ " output_mappings={\"chunks\": \"anchor\"},\n",
+ " )\n",
+ " generate_retrieval_pairs = GenerateSentencePair(\n",
+ " name=\"generate_retrieval_pairs\",\n",
+ " triplet=True,\n",
+ " hard_negative=True,\n",
+ " action=\"query\",\n",
+ " llm=llm,\n",
+ " input_batch_size=10,\n",
+ " context=context,\n",
+ " )\n",
+ " generate_reranking_pairs = GenerateSentencePair(\n",
+ " name=\"generate_reranking_pairs\",\n",
+ " triplet=True,\n",
+ " hard_negative=False, # to potentially generate non-relevant pairs\n",
+ " action=\"semantically-similar\",\n",
+ " llm=llm,\n",
+ " input_batch_size=10,\n",
+ " context=context,\n",
+ " )\n",
+ "\n",
+ " load_dataset.connect(generate_retrieval_pairs, generate_reranking_pairs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we can execute this using `pipeline.run`. We will provide some `parameters` to specific components within our pipeline."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "generation_kwargs = {\n",
+ " \"llm\": {\n",
+ " \"generation_kwargs\": {\n",
+ " \"temperature\": 0.7,\n",
+ " \"max_new_tokens\": 512,\n",
+ " }\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "distiset = pipeline.run( \n",
+ " parameters={\n",
+ " load_dataset.name: {\n",
+ " \"repo_id\": \"plaguss/argilla_sdk_docs_raw_unstructured\",\n",
+ " \"split\": \"train\",\n",
+ " },\n",
+ " generate_retrieval_pairs.name: generation_kwargs,\n",
+ " generate_reranking_pairs.name: generation_kwargs,\n",
+ " },\n",
+ " use_cache=False, # False for demo\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Data generation can be a expensive, so it is recommended to store the data somewhere. For now, we will store it on the Hugging Face Hub, using our `push_to_hub` method."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "distiset.push_to_hub(\"[your-owner-name]/example-retrieval-reranking-dataset\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We have got 2 different leaf/end nodes, therefore we've got a distil configurations we can access, one for the retrieval data, and one for the reranking data.\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Looking at these initial examples, we can see they nicely capture the essence of the `chunks` column but we will need to evaluate the quality of the data a bit more before we can use it for fine-tuning."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Data quality evaluation \n",
+ "\n",
+ "Data is never as clean as it can be and this also holds for synthetically generated data too, therefore, it is always good to spent some time and look at your data.\n",
+ "\n",
+ "### Feature engineering\n",
+ "\n",
+ "In order to evaluate the quality of our data we will use features of the models that we intent to fine-tune as proxy for data quality. We can then use these features to filter out the best examples.\n",
+ "\n",
+ "In order to choose a good default model, we will use the [Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard). We want to optimize for size and speed, so we will set model size `<100M` and then filter for `Retrieval` and `Reranking` based on the highest average score, resulting in [Snowflake/snowflake-arctic-embed-s](https://huggingface.co/Snowflake/snowflake-arctic-embed-s) and [sentence-transformers/all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2) respectively.\n",
+ "\n",
+ "\n",
+ "\n",
+ "#### Retrieval\n",
+ "\n",
+ "For retrieval, we will compute similarities for the current embeddings of `anchor-positive`, `positive-negative` and `anchor-negative` pairs. We assume that an overlap of these similarities will cause the model to have difficulties generalizing and therefore we can use these features to evaluate the quality of our data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_id = \"Snowflake/snowflake-arctic-embed-m\" # Hugging Face model ID\n",
+ "\n",
+ "model_retrieval = SentenceTransformer(\n",
+ " model_id, device=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we will encode the generated text pairs and compute the similarities. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.metrics.pairwise import cosine_similarity\n",
+ "\n",
+ "def get_embeddings(texts):\n",
+ " vectors = model_retrieval.encode(texts)\n",
+ " return [vector.tolist() for vector in vectors]\n",
+ "\n",
+ "\n",
+ "def get_similarities(vector_batch_a, vector_batch_b):\n",
+ " similarities = []\n",
+ " for vector_a, vector_b in zip(vector_batch_a, vector_batch_b):\n",
+ " similarity = cosine_similarity([vector_a], [vector_b])[0][0]\n",
+ " similarities.append(similarity)\n",
+ " return similarities\n",
+ "\n",
+ "def format_data_retriever(batch):# -> Any:\n",
+ " batch[\"anchor-vector\"] = get_embeddings(batch[\"anchor\"])\n",
+ " batch[\"positive-vector\"] = get_embeddings(batch[\"positive\"])\n",
+ " batch[\"negative-vector\"] = get_embeddings(batch[\"negative\"]) \n",
+ " batch[\"similarity-positive-negative\"] = get_similarities(batch[\"positive-vector\"], batch[\"negative-vector\"])\n",
+ " batch[\"similarity-anchor-positive\"] = get_similarities(batch[\"anchor-vector\"], batch[\"positive-vector\"])\n",
+ " batch[\"similarity-anchor-negative\"] = get_similarities(batch[\"anchor-vector\"], batch[\"negative-vector\"])\n",
+ " return batch\n",
+ "\n",
+ "dataset_generate_retrieval_pairs = distiset[\"generate_retrieval_pairs\"][\"train\"].map(format_data_retriever, batched=True, batch_size=250)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "#### Reranking\n",
+ "\n",
+ "For reranking, we will compute the compute the relevance scores from an existing reranker model for `anchor-positive`, `positive-negative` and `anchor-negative` pais and make a similar assumption as for the retrieval model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_id = \"sentence-transformers/all-MiniLM-L12-v2\"\n",
+ "\n",
+ "model = CrossEncoder(model_id)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we will compute the similarity for the generated text pairs using the reranker. On top of that, we will compute an `anchor-vector` to allow for doing semantic search."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def format_data_retriever(batch):# -> Any:\n",
+ " batch[\"anchor-vector\"] = get_embeddings(batch[\"anchor\"])\n",
+ " batch[\"similarity-positive-negative\"] = model.predict(zip(batch[\"positive-vector\"], batch[\"negative-vector\"]))\n",
+ " batch[\"similarity-anchor-positive\"] = model.predict(zip(batch[\"anchor-vector\"], batch[\"positive-vector\"]))\n",
+ " batch[\"similarity-anchor-negative\"] = model.predict(zip(batch[\"anchor-vector\"], batch[\"negative-vector\"]))\n",
+ " return batch\n",
+ "\n",
+ "dataset_generate_reranking_pairs = distiset[\"generate_reranking_pairs\"][\"train\"].map(format_data_retriever, batched=True, batch_size=250)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "And voila, we have our proxies for quality evaluation which we can use to filter out the best and worst examples."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### (Optional) Argilla\n",
+ "\n",
+ "To get the most out of you data and actually look at our data, we will use Argilla. If you are not familiar with Argilla, we recommend taking a look at the [Argilla quickstart docs](https://docs.argilla.io/latest/getting_started/quickstart/). Alternatively, you can use your Hugging Face account to login to the [Argilla demo Space](https://argilla-argilla-template-space.hf.space).\n",
+ "\n",
+ "To start exploring data, we first need to define an `argilla.Dataset`. We will create a basic datset with some input `TextFields` for the `anchor` and output `TextQuestions` for the `positive` and `negative` pairs. Additionally, we will use the `file_name` as `MetaDataProperty`. Lastly, we will be re-using the vectors obtained from our previous step to allow for semantic search and we will add te similarity scores for some basic filtering and sorting."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "First, we need to define the setting for our Argilla dataset. We will create two different datasets, one for the retrieval data and one for the reranking data to ensure our annotators can focus on the task at hand."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import argilla as rg\n",
+ "from argilla._exceptions import ConflictError\n",
+ "\n",
+ "api_key = \"ohh so secret\"\n",
+ "api_url = \"https://[your-owner-name]-[your-space-name].hf.space\"\n",
+ "\n",
+ "client = rg.Argilla(api_url=api_url, api_key=api_key)\n",
+ "\n",
+ "settings = rg.Settings(\n",
+ " fields=[\n",
+ " rg.TextField(\"anchor\")\n",
+ " ],\n",
+ " questions=[\n",
+ " rg.TextQuestion(\"positive\"),\n",
+ " rg.TextQuestion(\"negative\"),\n",
+ " rg.LabelQuestion(\n",
+ " name=\"is_positive_relevant\",\n",
+ " title=\"Is the positive query relevant?\",\n",
+ " labels=[\"yes\", \"no\"],\n",
+ " ),\n",
+ " rg.LabelQuestion(\n",
+ " name=\"is_negative_irrelevant\",\n",
+ " title=\"Is the negative query irrelevant?\",\n",
+ " labels=[\"yes\", \"no\"],\n",
+ " )\n",
+ " ],\n",
+ " metadata=[\n",
+ " rg.TermsMetadataProperty(\"filename\"),\n",
+ " rg.FloatMetadataProperty(\"similarity-positive-negative\"),\n",
+ " rg.FloatMetadataProperty(\"similarity-anchor-positive\"),\n",
+ " rg.FloatMetadataProperty(\"similarity-anchor-negative\"),\n",
+ " ],\n",
+ " vectors=[\n",
+ " rg.VectorField(\"anchor-vector\", dimensions=model.get_sentence_embedding_dimension())\n",
+ " ]\n",
+ ")\n",
+ "rg_datasets = []\n",
+ "for dataset_name in [\"generate_retrieval_pairs\", \"generate_reranking_pairs\"]:\n",
+ " ds = rg.Dataset(\n",
+ " name=dataset_name,\n",
+ " settings=settings\n",
+ " )\n",
+ " try:\n",
+ " ds.create()\n",
+ " except ConflictError:\n",
+ " ds = client.datasets(dataset_name)\n",
+ " rg_datasets.append(ds)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we've got our dataset definitions setup in Argilla, we can upload our data to Argilla."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ds_datasets = [dataset_generate_retrieval_pairs, dataset_generate_reranking_pairs]\n",
+ "\n",
+ "records = []\n",
+ "\n",
+ "for rg_dataset, ds_dataset in zip(rg_datasets, ds_datasets):\n",
+ " for idx, entry in enumerate(ds_dataset):\n",
+ " records.append(\n",
+ " rg.Record(\n",
+ " id=idx,\n",
+ " fields={\"anchor\": entry[\"anchor\"]},\n",
+ " suggestions=[\n",
+ " rg.Suggestion(\"positive\", value=entry[\"positive\"], agent=\"gpt-4o\", type=\"model\"),\n",
+ " rg.Suggestion(\"negative\", value=entry[\"negative\"], agent=\"gpt-4o\", type=\"model\"),\n",
+ " ],\n",
+ " metadata={\n",
+ " \"filename\": entry[\"filename\"],\n",
+ " \"similarity-positive-negative\": entry[\"similarity-positive-negative\"],\n",
+ " \"similarity-anchor-positive\": entry[\"similarity-anchor-positive\"],\n",
+ " \"similarity-anchor-negative\": entry[\"similarity-anchor-negative\"]\n",
+ " },\n",
+ " vectors={\"anchor-vector\": entry[\"anchor-vector\"]}\n",
+ " )\n",
+ " )\n",
+ " rg_dataset.records.log(records)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can explore the UI and add a final human touch to get he most out of our dataset. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Fine-tuning\n",
+ "\n",
+ "At last, we can fine-tune our models. We will use the `sentence-transformers` library to fine-tune our models.\n",
+ "\n",
+ "### Retrieval\n",
+ "\n",
+ "For retrieval, we have created a script that fine-tunes a model on our generated data the generated data based [https://github.com/argilla-io/argilla-sdk-chatbot/blob/main/train_embedding.ipynb](https://github.com/argilla-io/argilla-sdk-chatbot/blob/main/train_embedding.ipynb).You can also [open it in Google Colab directly](https://githubtocolab.com/argilla-io/argilla-sdk-chatbot/blob/main/train_embedding.ipynb)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Reranking\n",
+ "\n",
+ "For reranking, `sentence-transformers` provides a script that shows [how to fine-tune a CrossEncoder models](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/cross-encoder). Ad of now, there is [some uncertainty over fine-tuning CrossEncoder models with triplets](https://github.com/UKPLab/sentence-transformers/issues/2366) but you can still use the `positive` and `anchor`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Conclusions\n",
+ "\n",
+ "In this tutorial, we present an end-to-end example of fine-tuning retrievers and rerankers for RAG. This serves as a good starting point for optimizing and maintaining your data and model but need to be adapted to your specific use case.\n",
+ "\n",
+ "We started with some seed data from the Argilla docs, generated synthetic data for retrieval and reranking models, evaluated the quality of the data, and showed how to fine-tune the models. We also used Argilla to get a human touch on the data."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".env",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb b/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb
new file mode 100644
index 0000000000..de1e9fd264
--- /dev/null
+++ b/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb
@@ -0,0 +1,596 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Clean an existing preference dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- **Goal**: Clean an existing preference dataset by providing AI feedback on the quality of the data.\n",
+ "- **Libraries**: [argilla](https://github.com/argilla-io/argilla), [hf-inference-endpoints](https://github.com/huggingface/huggingface_hub)\n",
+ "- **Components**: [LoadDataFromDicts](https://distilabel.argilla.io/dev/components-gallery/steps/loaddatafromdicts/), [UltraFeedback](https://distilabel.argilla.io/latest/components-gallery/tasks/ultrafeedback/), [KeepColumns](https://distilabel.argilla.io/latest/components-gallery/steps/groupcolumns/), [PreferenceToArgilla](https://distilabel.argilla.io/latest/components-gallery/steps/textgenerationtoargilla/), [InferenceEndpointsLLM](https://distilabel.argilla.io/latest/components-gallery/llms/inferenceendpointsllm/), [GlobalStep](../../how_to_guides/basic/step/global_step.md)\n",
+ "\n",
+ "![Knowledge graph figure](../../../assets/pipelines/clean-dataset.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Getting Started"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Install the dependencies\n",
+ "\n",
+ "To complete this tutorial, you need to install the distilabel SDK and a few third-party libraries via pip. We will be using **the free but rate-limited Hugging Face serverless Inference API** for this tutorial, so we need to install this as an extra distilabel dependency. You can install them by running the following command:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"distilabel[hf-inference-endpoints]\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"transformers~=4.0\" \"torch~=2.0\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's make the required imports:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "from distilabel.llms import InferenceEndpointsLLM\n",
+ "from distilabel.pipeline import Pipeline\n",
+ "from distilabel.steps import (\n",
+ " KeepColumns,\n",
+ " LoadDataFromDicts,\n",
+ " PreferenceToArgilla,\n",
+ ")\n",
+ "from distilabel.steps.tasks import UltraFeedback"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You'll need an `HF_TOKEN` to use the HF Inference Endpoints. Login to use it directly within this notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from huggingface_hub import login\n",
+ "\n",
+ "login(token=os.getenv(\"HF_TOKEN\"), add_to_git_credential=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### (optional) Deploy Argilla\n",
+ "\n",
+ "You can skip this step or replace it with any other data evaluation tool, but the quality of your model will suffer from a lack of data quality, so we do recommend looking at your data. If you already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/). \n",
+ "\n",
+ "Along with that, you will need to install Argilla as a distilabel extra."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"distilabel[argilla, hf-inference-endpoints]\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## The dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this case, we will clean a preference dataset, so we will use the [`Intel/orca_dpo_pairs`](https://huggingface.co/datasets/Intel/orca_dpo_pairs) dataset from the Hugging Face Hub."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = load_dataset(\"Intel/orca_dpo_pairs\", split=\"train[:20]\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we will shuffle the `chosen` and `rejected` columns to avoid any bias in the dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def shuffle_and_track(chosen, rejected):\n",
+ " pair = [chosen, rejected]\n",
+ " random.shuffle(pair)\n",
+ " order = [\"chosen\" if x == chosen else \"rejected\" for x in pair]\n",
+ " return {\"generations\": pair, \"order\": order}\n",
+ "\n",
+ "dataset = dataset.map(lambda x: shuffle_and_track(x[\"chosen\"], x[\"rejected\"]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = dataset.to_list()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "??? tip \"As a custom step\"\n",
+ " You can also [create a custom step](../../how_to_guides/basic/step/global_step.md) in a separate module, import it and add it to the pipeline after loading the `orca_dpo_pairs` dataset using the `LoadDataFromHub` step.\n",
+ "\n",
+ " ```python title=\"shuffle_step.py\"\n",
+ " from typing import TYPE_CHECKING, List\n",
+ " from distilabel.steps import GlobalStep, StepInput\n",
+ "\n",
+ " if TYPE_CHECKING:\n",
+ " from distilabel.steps.typing import StepOutput\n",
+ " \n",
+ " import random\n",
+ "\n",
+ " class ShuffleStep(GlobalStep):\n",
+ " @property\n",
+ " def inputs(self):\n",
+ " \"\"\"Returns List[str]: The inputs of the step.\"\"\"\n",
+ " return [\"instruction\", \"chosen\", \"rejected\"]\n",
+ "\n",
+ " @property\n",
+ " def outputs(self):\n",
+ " \"\"\"Returns List[str]: The outputs of the step.\"\"\"\n",
+ " return [\"instruction\", \"generations\", \"order\"]\n",
+ "\n",
+ " def process(self, inputs: StepInput):\n",
+ " \"\"\"Returns StepOutput: The outputs of the step.\"\"\"\n",
+ " outputs = []\n",
+ "\n",
+ " for input in inputs:\n",
+ " chosen = input[\"chosen\"]\n",
+ " rejected = input[\"rejected\"]\n",
+ " pair = [chosen, rejected]\n",
+ " random.shuffle(pair)\n",
+ " order = [\"chosen\" if x == chosen else \"rejected\" for x in pair]\n",
+ " \n",
+ " outputs.append({\"instruction\": input[\"instruction\"], \"generations\": pair, \"order\": order})\n",
+ "\n",
+ " yield outputs\n",
+ " ```\n",
+ " \n",
+ " ```python\n",
+ " from shuffle_step import ShuffleStep\n",
+ " ```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define the pipeline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To clean an existing preference dataset, we will need to define a `Pipeline` with all the necessary steps. However, a similar workflow can be used to clean a SFT dataset. Below, we will go over each step in detail."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load the dataset\n",
+ "We will use the dataset we just shuffled as source data.\n",
+ "\n",
+ "- Component: `LoadDataFromDicts`\n",
+ "- Input columns: `system`, `question`, `chosen`, `rejected`, `generations` and `order`, the same keys as in the loaded list of dictionaries.\n",
+ "- Output columns: `system`, `instruction`, `chosen`, `rejected`, `generations` and `order`. We will use `output_mappings` to rename the columns."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "([{'system': '',\n",
+ " 'question': \"You will be given a definition of a task first, then some input of the task.\\nThis task is about using the specified sentence and converting the sentence to Resource Description Framework (RDF) triplets of the form (subject, predicate object). The RDF triplets generated must be such that the triplets accurately capture the structure and semantics of the input sentence. The input is a sentence and the output is a list of triplets of the form [subject, predicate, object] that capture the relationships present in the sentence. When a sentence has more than 1 RDF triplet possible, the output must contain all of them.\\n\\nAFC Ajax (amateurs)'s ground is Sportpark De Toekomst where Ajax Youth Academy also play.\\nOutput:\",\n",
+ " 'chosen': '[\\n [\"AFC Ajax (amateurs)\", \"has ground\", \"Sportpark De Toekomst\"],\\n [\"Ajax Youth Academy\", \"plays at\", \"Sportpark De Toekomst\"]\\n]',\n",
+ " 'rejected': \" Sure, I'd be happy to help! Here are the RDF triplets for the input sentence:\\n\\n[AFC Ajax (amateurs), hasGround, Sportpark De Toekomst]\\n[Ajax Youth Academy, playsAt, Sportpark De Toekomst]\\n\\nExplanation:\\n\\n* AFC Ajax (amateurs) is the subject of the first triplet, and hasGround is the predicate that describes the relationship between AFC Ajax (amateurs) and Sportpark De Toekomst.\\n* Ajax Youth Academy is the subject of the second triplet, and playsAt is the predicate that describes the relationship between Ajax Youth Academy and Sportpark De Toekomst.\\n\\nNote that there may be other possible RDF triplets that could be derived from the input sentence, but the above triplets capture the main relationships present in the sentence.\",\n",
+ " 'generations': [\" Sure, I'd be happy to help! Here are the RDF triplets for the input sentence:\\n\\n[AFC Ajax (amateurs), hasGround, Sportpark De Toekomst]\\n[Ajax Youth Academy, playsAt, Sportpark De Toekomst]\\n\\nExplanation:\\n\\n* AFC Ajax (amateurs) is the subject of the first triplet, and hasGround is the predicate that describes the relationship between AFC Ajax (amateurs) and Sportpark De Toekomst.\\n* Ajax Youth Academy is the subject of the second triplet, and playsAt is the predicate that describes the relationship between Ajax Youth Academy and Sportpark De Toekomst.\\n\\nNote that there may be other possible RDF triplets that could be derived from the input sentence, but the above triplets capture the main relationships present in the sentence.\",\n",
+ " '[\\n [\"AFC Ajax (amateurs)\", \"has ground\", \"Sportpark De Toekomst\"],\\n [\"Ajax Youth Academy\", \"plays at\", \"Sportpark De Toekomst\"]\\n]'],\n",
+ " 'order': ['rejected', 'chosen']}],\n",
+ " True)"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "load_dataset = LoadDataFromDicts(\n",
+ " data=dataset[:1],\n",
+ " output_mappings={\"question\": \"instruction\"},\n",
+ " pipeline=Pipeline(name=\"showcase-pipeline\"),\n",
+ ")\n",
+ "load_dataset.load()\n",
+ "next(load_dataset.process())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Evaluate the responses\n",
+ "\n",
+ "To evaluate the quality of the responses, we will use [`meta-llama/Meta-Llama-3.1-70B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct), applying the `UltraFeedback` task that judges the responses according to different dimensions (helpfulness, honesty, instruction-following, truthfulness). For an SFT dataset, you can use [`PrometheusEval`](../papers/prometheus.md) instead.\n",
+ "\n",
+ "- Component: `UltraFeedback` task with LLMs using `InferenceEndpointsLLM`\n",
+ "- Input columns: `instruction`, `generations`\n",
+ "- Output columns: `ratings`, `rationales`, `distilabel_metadata`, `model_name`\n",
+ "\n",
+ "For your use case and to improve the results, you can use any [other LLM of your choice](https://distilabel.argilla.io/latest/components-gallery/llms/)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'instruction': \"What's the capital of Spain?\",\n",
+ " 'generations': ['Madrid', 'Barcelona'],\n",
+ " 'ratings': [5, 1],\n",
+ " 'rationales': [\"The answer is correct, directly addressing the question, and is free of hallucinations or unnecessary details. It confidently provides the accurate information, aligning perfectly with the user's intent.\",\n",
+ " \"The answer is incorrect as Barcelona is not the capital of Spain. This introduces a significant inaccuracy, failing to provide helpful information and deviating entirely from the user's intent.\"],\n",
+ " 'distilabel_metadata': {'raw_output_ultra_feedback_0': \"#### Output for Text 1\\nRating: 5 (Excellent)\\nRationale: The answer is correct, directly addressing the question, and is free of hallucinations or unnecessary details. It confidently provides the accurate information, aligning perfectly with the user's intent.\\n\\n#### Output for Text 2\\nRating: 1 (Low Quality)\\nRationale: The answer is incorrect as Barcelona is not the capital of Spain. This introduces a significant inaccuracy, failing to provide helpful information and deviating entirely from the user's intent.\"},\n",
+ " 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "evaluate_responses = UltraFeedback(\n",
+ " aspect=\"overall-rating\",\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\n",
+ " tokenizer_id=\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.7},\n",
+ " ),\n",
+ " pipeline=Pipeline(name=\"showcase-pipeline\"),\n",
+ ")\n",
+ "evaluate_responses.load()\n",
+ "next(\n",
+ " evaluate_responses.process(\n",
+ " [\n",
+ " {\n",
+ " \"instruction\": \"What's the capital of Spain?\",\n",
+ " \"generations\": [\"Madrid\", \"Barcelona\"],\n",
+ " }\n",
+ " ]\n",
+ " )\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Keep only the required columns\n",
+ "\n",
+ "We will get rid of the unneeded columns.\n",
+ "\n",
+ "- Component: `KeepColumns`\n",
+ "- Input columns: `system`, `instruction`, `chosen`, `rejected`, `generations`, `ratings`, `rationales`, `distilabel_metadata` and `model_name`\n",
+ "- Output columns: `instruction`, `chosen`, `rejected`, `generations` and `order`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'instruction': \"What's the capital of Spain?\",\n",
+ " 'generations': ['Madrid', 'Barcelona'],\n",
+ " 'order': ['chosen', 'rejected'],\n",
+ " 'ratings': [5, 1],\n",
+ " 'rationales': ['', ''],\n",
+ " 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "keep_columns = KeepColumns(\n",
+ " columns=[\n",
+ " \"instruction\",\n",
+ " \"generations\",\n",
+ " \"order\",\n",
+ " \"ratings\",\n",
+ " \"rationales\",\n",
+ " \"model_name\",\n",
+ " ],\n",
+ " pipeline=Pipeline(name=\"showcase-pipeline\"),\n",
+ ")\n",
+ "keep_columns.load()\n",
+ "next(\n",
+ " keep_columns.process(\n",
+ " [\n",
+ " {\n",
+ " \"system\": \"\",\n",
+ " \"instruction\": \"What's the capital of Spain?\",\n",
+ " \"chosen\": \"Madrid\",\n",
+ " \"rejected\": \"Barcelona\",\n",
+ " \"generations\": [\"Madrid\", \"Barcelona\"],\n",
+ " \"order\": [\"chosen\", \"rejected\"],\n",
+ " \"ratings\": [5, 1],\n",
+ " \"rationales\": [\"\", \"\"],\n",
+ " \"model_name\": \"meta-llama/Meta-Llama-3.1-70B-Instruct\",\n",
+ " }\n",
+ " ]\n",
+ " )\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### (Optional) Further data curation\n",
+ "\n",
+ "You can use Argilla to further curate your data.\n",
+ "\n",
+ "- Component: `PreferenceToArgilla` step\n",
+ "- Input columns: `instruction`, `generations`, `generation_models`, `ratings`\n",
+ "- Output columns: `instruction`, `generations`, `generation_models`, `ratings`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "to_argilla = PreferenceToArgilla(\n",
+ " dataset_name=\"cleaned-dataset\",\n",
+ " dataset_workspace=\"argilla\",\n",
+ " api_url=\"https://[your-owner-name]-[your-space-name].hf.space\",\n",
+ " api_key=\"[your-api-key]\",\n",
+ " num_generations=2\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run the pipeline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below, you can see the full pipeline definition:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with Pipeline(name=\"clean-dataset\") as pipeline:\n",
+ "\n",
+ " load_dataset = LoadDataFromDicts(\n",
+ " data=dataset, output_mappings={\"question\": \"instruction\"}\n",
+ " )\n",
+ "\n",
+ " evaluate_responses = UltraFeedback(\n",
+ " aspect=\"overall-rating\",\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\n",
+ " tokenizer_id=\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.7},\n",
+ " ),\n",
+ " )\n",
+ "\n",
+ " keep_columns = KeepColumns(\n",
+ " columns=[\n",
+ " \"instruction\",\n",
+ " \"generations\",\n",
+ " \"order\",\n",
+ " \"ratings\",\n",
+ " \"rationales\",\n",
+ " \"model_name\",\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " to_argilla = PreferenceToArgilla(\n",
+ " dataset_name=\"cleaned-dataset\",\n",
+ " dataset_workspace=\"argilla\",\n",
+ " api_url=\"https://[your-owner-name]-[your-space-name].hf.space\",\n",
+ " api_key=\"[your-api-key]\",\n",
+ " num_generations=2,\n",
+ " )\n",
+ "\n",
+ " load_dataset.connect(evaluate_responses)\n",
+ " evaluate_responses.connect(keep_columns)\n",
+ " keep_columns.connect(to_argilla)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's now run the pipeline and clean our preference dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "distiset = pipeline.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's check it! If you have loaded the data to Argilla, you can [start annotating in the Argilla UI](https://docs.argilla.io/latest/how_to_guides/annotate/)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can push the dataset to the Hub for sharing with the community and [embed it to explore the data](https://huggingface.co/docs/hub/datasets-viewer-embed)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "distiset.push_to_hub(\"[your-owner-name]/example-cleaned-preference-dataset\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Conclusions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial, we showcased the detailed steps to build a pipeline for cleaning a preference dataset using distilabel. However, you can customize this pipeline for your own use cases, such as cleaning an SFT dataset or adding custom steps.\n",
+ "\n",
+ "We used a preference dataset as our starting point and shuffled the data to avoid any bias. Next, we evaluated the responses using a model through the serverless Hugging Face Inference API, following the UltraFeedback standards. Finally, we kept the needed columns and used Argilla for further curation."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "distilabel-tutorials",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/sections/pipeline_samples/tutorials/generate_preference_dataset.ipynb b/docs/sections/pipeline_samples/tutorials/generate_preference_dataset.ipynb
new file mode 100644
index 0000000000..a81e8051ad
--- /dev/null
+++ b/docs/sections/pipeline_samples/tutorials/generate_preference_dataset.ipynb
@@ -0,0 +1,596 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Generate a preference dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- **Goal**: Generate a synthetic preference dataset for DPO/ORPO.\n",
+ "- **Libraries**: [argilla](https://github.com/argilla-io/argilla), [hf-inference-endpoints](https://github.com/huggingface/huggingface_hub)\n",
+ "- **Components**: [LoadDataFromHub](https://distilabel.argilla.io/latest/components-gallery/steps/loaddatafromhub/), [TextGeneration](https://distilabel.argilla.io/latest/components-gallery/tasks/textgeneration/), [UltraFeedback](https://distilabel.argilla.io/latest/components-gallery/tasks/ultrafeedback/), [GroupColumns](https://distilabel.argilla.io/latest/components-gallery/steps/groupcolumns/), [FormatTextGenerationDPO](https://distilabel.argilla.io/latest/components-gallery/steps/formattextgenerationdpo/), [PreferenceToArgilla](https://distilabel.argilla.io/latest/components-gallery/steps/textgenerationtoargilla/), [InferenceEndpointsLLM](https://distilabel.argilla.io/latest/components-gallery/llms/inferenceendpointsllm/)\n",
+ "\n",
+ "![Knowledge graph figure](../../../assets/pipelines/generate-preference-dataset.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Getting started"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Install the dependencies\n",
+ "\n",
+ "To complete this tutorial, you need to install the distilabel SDK and a few third-party libraries via pip. We will be using **the free but rate-limited Hugging Face serverless Inference API** for this tutorial, so we need to install this as an extra distilabel dependency. You can install them by running the following command:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"distilabel[hf-inference-endpoints]\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"transformers~=4.0\" \"torch~=2.0\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's make the required imports:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from distilabel.llms import InferenceEndpointsLLM\n",
+ "from distilabel.pipeline import Pipeline\n",
+ "from distilabel.steps import (\n",
+ " LoadDataFromHub,\n",
+ " GroupColumns,\n",
+ " FormatTextGenerationDPO,\n",
+ " PreferenceToArgilla,\n",
+ ")\n",
+ "from distilabel.steps.tasks import TextGeneration, UltraFeedback"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You'll need an `HF_TOKEN` to use the HF Inference Endpoints. Log in to use it directly within this notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from huggingface_hub import login\n",
+ "\n",
+ "login(token=os.getenv(\"HF_TOKEN\"), add_to_git_credential=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### (optional) Deploy Argilla\n",
+ "\n",
+ "You can skip this step or replace it with any other data evaluation tool, but the quality of your model will suffer from a lack of data quality, so we do recommend looking at your data. If you already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/). \n",
+ "\n",
+ "Along with that, you will need to install Argilla as a distilabel extra."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"distilabel[argilla, hf-inference-endpoints]\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define the pipeline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To generate our preference dataset, we will need to define a `Pipeline` with all the necessary steps. Below, we will go over each step in detail."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load the dataset\n",
+ "\n",
+ "We will use as source data the [`argilla/10Kprompts-mini`](https://huggingface.co/datasets/argilla/10Kprompts-mini) dataset from the Hugging Face Hub.\n",
+ "\n",
+ "\n",
+ "\n",
+ "- Component: `LoadDataFromHub`\n",
+ "- Input columns: `instruction` and `topic`, the same as in the loaded dataset\n",
+ "- Output columns: `instruction` and `topic`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "([{'instruction': 'How can I create an efficient and robust workflow that utilizes advanced automation techniques to extract targeted data, including customer information, from diverse PDF documents and effortlessly integrate it into a designated Google Sheet? Furthermore, I am interested in establishing a comprehensive and seamless system that promptly activates an SMS notification on my mobile device whenever a new PDF document is uploaded to the Google Sheet, ensuring real-time updates and enhanced accessibility.',\n",
+ " 'topic': 'Software Development'}],\n",
+ " True)"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "load_dataset = LoadDataFromHub(\n",
+ " repo_id= \"argilla/10Kprompts-mini\",\n",
+ " num_examples=1,\n",
+ " pipeline=Pipeline(name=\"showcase-pipeline\"),\n",
+ " )\n",
+ "load_dataset.load()\n",
+ "next(load_dataset.process())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Generate responses\n",
+ "\n",
+ "We need to generate the responses for the given instructions. We will use two different models available on the Hugging Face Hub through the Serverless Inference API: [`meta-llama/Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and [`mistralai/Mixtral-8x7B-Instruct-v0.1`](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1). We will also indicate the generation parameters for each model.\n",
+ "\n",
+ "- Component: `TextGeneration` task with LLMs using `InferenceEndpointsLLM`\n",
+ "- Input columns: `instruction`\n",
+ "- Output columns: `generation`, `distilabel_metadata`, `model_name` for each model\n",
+ "\n",
+ "For your use case and to improve the results, you can use any [other LLM of your choice](https://distilabel.argilla.io/latest/components-gallery/llms/)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[{'instruction': 'Which are the top cities in Spain?', 'generation': 'Spain is a country with a rich culture, history, and architecture, and it has many great cities to visit. Here are some of the top cities in Spain:\\n\\n1. **Madrid**: The capital city of Spain, known for its vibrant nightlife, museums, and historic landmarks like the Royal Palace and Prado Museum.\\n2. **Barcelona**: The second-largest city in Spain, famous for its modernist architecture, beaches, and iconic landmarks like La Sagrada Família and Park Güell, designed by Antoni Gaudí.\\n3. **Valencia**: Located on the Mediterranean coast, Valencia is known for its beautiful beaches, City of Arts and Sciences, and delicious local cuisine, such as paella.\\n4. **Seville**: The capital of Andalusia, Seville is famous for its stunning cathedral, Royal Alcázar Palace, and lively flamenco music scene.\\n5. **Málaga**: A coastal city in southern Spain, Málaga is known for its rich history, beautiful beaches, and being the birthplace of Pablo Picasso.\\n6. **Zaragoza**: Located in the northeastern region of Aragon, Zaragoza is a city with a rich history, known for its Roman ruins, Gothic cathedral, and beautiful parks.\\n7. **Granada**: A city in the Andalusian region, Granada is famous for its stunning Alhambra palace and generalife gardens, a UNESCO World Heritage Site.\\n8. **Bilbao**: A city in the Basque Country, Bilbao is known for its modern architecture, including the Guggenheim Museum, and its rich cultural heritage.\\n9. **Alicante**: A coastal city in the Valencia region, Alicante is famous for its beautiful beaches, historic castle, and lively nightlife.\\n10. **San Sebastián**: A city in the Basque Country, San Sebastián is known for its stunning beaches, gastronomic scene, and cultural events like the San Sebastián International Film Festival.\\n\\nThese are just a few of the many great cities in Spain, each with its own unique character and attractions.', 'distilabel_metadata': {'raw_output_text_generation_0': 'Spain is a country with a rich culture, history, and architecture, and it has many great cities to visit. Here are some of the top cities in Spain:\\n\\n1. **Madrid**: The capital city of Spain, known for its vibrant nightlife, museums, and historic landmarks like the Royal Palace and Prado Museum.\\n2. **Barcelona**: The second-largest city in Spain, famous for its modernist architecture, beaches, and iconic landmarks like La Sagrada Família and Park Güell, designed by Antoni Gaudí.\\n3. **Valencia**: Located on the Mediterranean coast, Valencia is known for its beautiful beaches, City of Arts and Sciences, and delicious local cuisine, such as paella.\\n4. **Seville**: The capital of Andalusia, Seville is famous for its stunning cathedral, Royal Alcázar Palace, and lively flamenco music scene.\\n5. **Málaga**: A coastal city in southern Spain, Málaga is known for its rich history, beautiful beaches, and being the birthplace of Pablo Picasso.\\n6. **Zaragoza**: Located in the northeastern region of Aragon, Zaragoza is a city with a rich history, known for its Roman ruins, Gothic cathedral, and beautiful parks.\\n7. **Granada**: A city in the Andalusian region, Granada is famous for its stunning Alhambra palace and generalife gardens, a UNESCO World Heritage Site.\\n8. **Bilbao**: A city in the Basque Country, Bilbao is known for its modern architecture, including the Guggenheim Museum, and its rich cultural heritage.\\n9. **Alicante**: A coastal city in the Valencia region, Alicante is famous for its beautiful beaches, historic castle, and lively nightlife.\\n10. **San Sebastián**: A city in the Basque Country, San Sebastián is known for its stunning beaches, gastronomic scene, and cultural events like the San Sebastián International Film Festival.\\n\\nThese are just a few of the many great cities in Spain, each with its own unique character and attractions.'}, 'model_name': 'meta-llama/Meta-Llama-3-8B-Instruct'}]\n",
+ "[{'instruction': 'Which are the top cities in Spain?', 'generation': ' Here are some of the top cities in Spain based on various factors such as tourism, culture, history, and quality of life:\\n\\n1. Madrid: The capital and largest city in Spain, Madrid is known for its vibrant nightlife, world-class museums (such as the Prado Museum and Reina Sofia Museum), stunning parks (such as the Retiro Park), and delicious food.\\n\\n2. Barcelona: Famous for its unique architecture, Barcelona is home to several UNESCO World Heritage sites designed by Antoni Gaudí, including the Sagrada Familia and Park Güell. The city also boasts beautiful beaches, a lively arts scene, and delicious Catalan cuisine.\\n\\n3. Valencia: A coastal city located in the east of Spain, Valencia is known for its City of Arts and Sciences, a modern architectural complex that includes a planetarium, opera house, and museum of interactive science. The city is also famous for its paella, a traditional Spanish dish made with rice, vegetables, and seafood.\\n\\n4. Seville: The capital of Andalusia, Seville is famous for its flamenco dancing, stunning cathedral (the largest Gothic cathedral in the world), and the Alcázar, a beautiful palace made up of a series of rooms and courtyards.\\n\\n5. Granada: Located in the foothills of the Sierra Nevada mountains, Granada is known for its stunning Alhambra palace, a Moorish fortress that dates back to the 9th century. The city is also famous for its tapas, a traditional Spanish dish that is often served for free with drinks.\\n\\n6. Bilbao: A city in the Basque Country, Bilbao is famous for its modern architecture, including the Guggenheim Museum, a contemporary art museum designed by Frank Gehry. The city is also known for its pintxos, a type of Basque tapas that are served in bars and restaurants.\\n\\n7. Málaga: A coastal city in Andalusia, Málaga is known for its beautiful beaches, historic sites (including the Alcazaba and Gibralfaro castles), and the Picasso Museum, which is dedicated to the famous Spanish artist who was born in the city.\\n\\nThese are just a few of the many wonderful cities in Spain.', 'distilabel_metadata': {'raw_output_text_generation_0': ' Here are some of the top cities in Spain based on various factors such as tourism, culture, history, and quality of life:\\n\\n1. Madrid: The capital and largest city in Spain, Madrid is known for its vibrant nightlife, world-class museums (such as the Prado Museum and Reina Sofia Museum), stunning parks (such as the Retiro Park), and delicious food.\\n\\n2. Barcelona: Famous for its unique architecture, Barcelona is home to several UNESCO World Heritage sites designed by Antoni Gaudí, including the Sagrada Familia and Park Güell. The city also boasts beautiful beaches, a lively arts scene, and delicious Catalan cuisine.\\n\\n3. Valencia: A coastal city located in the east of Spain, Valencia is known for its City of Arts and Sciences, a modern architectural complex that includes a planetarium, opera house, and museum of interactive science. The city is also famous for its paella, a traditional Spanish dish made with rice, vegetables, and seafood.\\n\\n4. Seville: The capital of Andalusia, Seville is famous for its flamenco dancing, stunning cathedral (the largest Gothic cathedral in the world), and the Alcázar, a beautiful palace made up of a series of rooms and courtyards.\\n\\n5. Granada: Located in the foothills of the Sierra Nevada mountains, Granada is known for its stunning Alhambra palace, a Moorish fortress that dates back to the 9th century. The city is also famous for its tapas, a traditional Spanish dish that is often served for free with drinks.\\n\\n6. Bilbao: A city in the Basque Country, Bilbao is famous for its modern architecture, including the Guggenheim Museum, a contemporary art museum designed by Frank Gehry. The city is also known for its pintxos, a type of Basque tapas that are served in bars and restaurants.\\n\\n7. Málaga: A coastal city in Andalusia, Málaga is known for its beautiful beaches, historic sites (including the Alcazaba and Gibralfaro castles), and the Picasso Museum, which is dedicated to the famous Spanish artist who was born in the city.\\n\\nThese are just a few of the many wonderful cities in Spain.'}, 'model_name': 'mistralai/Mixtral-8x7B-Instruct-v0.1'}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "generate_responses = [\n",
+ " TextGeneration(\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
+ " tokenizer_id=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.7},\n",
+ " ),\n",
+ " pipeline=Pipeline(name=\"showcase-pipeline\"),\n",
+ " ),\n",
+ " TextGeneration(\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
+ " tokenizer_id=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.7},\n",
+ " ),\n",
+ " pipeline=Pipeline(name=\"showcase-pipeline\"),\n",
+ " ),\n",
+ "]\n",
+ "for task in generate_responses:\n",
+ " task.load()\n",
+ " print(next(task.process([{\"instruction\": \"Which are the top cities in Spain?\"}])))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Group the responses\n",
+ "\n",
+ "The task to evaluate the responses needs as input a list of generations. However, each model response was saved in the generation column of the subsets `text_generation_0` and `text_generation_1`. We will combine these two columns into a single column and the `default` subset.\n",
+ "\n",
+ "- Component: `GroupColumns`\n",
+ "- Input columns: `generation` and `model_name`from `text_generation_0` and `text_generation_1`\n",
+ "- Output columns: `generations` and `model_names`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'generations': ['Madrid', 'Barcelona'],\n",
+ " 'model_names': ['meta-llama/Meta-Llama-3-8B-Instruct',\n",
+ " 'mistralai/Mixtral-8x7B-Instruct-v0.1']}]"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "group_responses = GroupColumns(\n",
+ " columns=[\"generation\", \"model_name\"],\n",
+ " output_columns=[\"generations\", \"model_names\"],\n",
+ " pipeline=Pipeline(name=\"showcase-pipeline\"),\n",
+ ")\n",
+ "next(\n",
+ " group_responses.process(\n",
+ " [\n",
+ " {\n",
+ " \"generation\": \"Madrid\",\n",
+ " \"model_name\": \"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
+ " },\n",
+ " ],\n",
+ " [\n",
+ " {\n",
+ " \"generation\": \"Barcelona\",\n",
+ " \"model_name\": \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
+ " }\n",
+ " ],\n",
+ " )\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Evaluate the responses\n",
+ "\n",
+ "To build our preference dataset, we need to evaluate the responses generated by the models. We will use [`meta-llama/Meta-Llama-3-70B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) for this, applying the `UltraFeedback` task that judges the responses according to different dimensions (helpfulness, honesty, instruction-following, truthfulness).\n",
+ "\n",
+ "- Component: `UltraFeedback` task with LLMs using `InferenceEndpointsLLM`\n",
+ "- Input columns: `instruction`, `generations`\n",
+ "- Output columns: `ratings`, `rationales`, `distilabel_metadata`, `model_name`\n",
+ "\n",
+ "For your use case and to improve the results, you can use any [other LLM of your choice](https://distilabel.argilla.io/latest/components-gallery/llms/)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'instruction': \"What's the capital of Spain?\",\n",
+ " 'generations': ['Madrid', 'Barcelona'],\n",
+ " 'ratings': [5, 1],\n",
+ " 'rationales': [\"The answer is correct, directly addressing the question, and is free of hallucinations or unnecessary details. It confidently provides the accurate information, aligning perfectly with the user's intent.\",\n",
+ " \"The answer is incorrect as Barcelona is not the capital of Spain. This introduces a significant inaccuracy, failing to provide helpful information and deviating entirely from the user's intent.\"],\n",
+ " 'distilabel_metadata': {'raw_output_ultra_feedback_0': \"#### Output for Text 1\\nRating: 5 (Excellent)\\nRationale: The answer is correct, directly addressing the question, and is free of hallucinations or unnecessary details. It confidently provides the accurate information, aligning perfectly with the user's intent.\\n\\n#### Output for Text 2\\nRating: 1 (Low Quality)\\nRationale: The answer is incorrect as Barcelona is not the capital of Spain. This introduces a significant inaccuracy, failing to provide helpful information and deviating entirely from the user's intent.\"},\n",
+ " 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct'}]"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "evaluate_responses = UltraFeedback(\n",
+ " aspect=\"overall-rating\",\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"meta-llama/Meta-Llama-3-70B-Instruct\",\n",
+ " tokenizer_id=\"meta-llama/Meta-Llama-3-70B-Instruct\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.7},\n",
+ " ),\n",
+ " pipeline=Pipeline(name=\"showcase-pipeline\"),\n",
+ ")\n",
+ "evaluate_responses.load()\n",
+ "next(\n",
+ " evaluate_responses.process(\n",
+ " [\n",
+ " {\n",
+ " \"instruction\": \"What's the capital of Spain?\",\n",
+ " \"generations\": [\"Madrid\", \"Barcelona\"],\n",
+ " }\n",
+ " ]\n",
+ " )\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Convert to a preference dataset\n",
+ "\n",
+ "- You can automatically convert it to a preference dataset with the `chosen` and `rejected` columns.\n",
+ " - Component: `FormatTextGenerationDPO` step\n",
+ " - Input columns: `instruction`, `generations`, `generation_models`, `ratings`\n",
+ " - Output columns: `prompt`, `prompt_id`, `chosen`, `chosen_model`, `chosen_rating`, `rejected`, `rejected_model`, `rejected_rating`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'instruction': \"What's the capital of Spain?\",\n",
+ " 'generations': ['Madrid', 'Barcelona'],\n",
+ " 'generation_models': ['Meta-Llama-3-8B-Instruct',\n",
+ " 'Mixtral-8x7B-Instruct-v0.1'],\n",
+ " 'ratings': [5, 1],\n",
+ " 'prompt': \"What's the capital of Spain?\",\n",
+ " 'prompt_id': '26174c953df26b3049484e4721102dca6b25d2de9e3aa22aa84f25ed1c798512',\n",
+ " 'chosen': [{'role': 'user', 'content': \"What's the capital of Spain?\"},\n",
+ " {'role': 'assistant', 'content': 'Madrid'}],\n",
+ " 'chosen_model': 'Meta-Llama-3-8B-Instruct',\n",
+ " 'chosen_rating': 5,\n",
+ " 'rejected': [{'role': 'user', 'content': \"What's the capital of Spain?\"},\n",
+ " {'role': 'assistant', 'content': 'Barcelona'}],\n",
+ " 'rejected_model': 'Mixtral-8x7B-Instruct-v0.1',\n",
+ " 'rejected_rating': 1}]"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "format_dpo = FormatTextGenerationDPO(pipeline=Pipeline(name=\"showcase-pipeline\"))\n",
+ "format_dpo.load()\n",
+ "next(\n",
+ " format_dpo.process(\n",
+ " [\n",
+ " {\n",
+ " \"instruction\": \"What's the capital of Spain?\",\n",
+ " \"generations\": [\"Madrid\", \"Barcelona\"],\n",
+ " \"generation_models\": [\n",
+ " \"Meta-Llama-3-8B-Instruct\",\n",
+ " \"Mixtral-8x7B-Instruct-v0.1\",\n",
+ " ],\n",
+ " \"ratings\": [5, 1],\n",
+ " }\n",
+ " ]\n",
+ " )\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- Or you can use Argilla to manually label the data and convert it to a preference dataset.\n",
+ " - Component: `PreferenceToArgilla` step\n",
+ " - Input columns: `instruction`, `generations`, `generation_models`, `ratings`\n",
+ " - Output columns: `instruction`, `generations`, `generation_models`, `ratings`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "to_argilla = PreferenceToArgilla(\n",
+ " dataset_name=\"preference-dataset\",\n",
+ " dataset_workspace=\"argilla\",\n",
+ " api_url=\"https://[your-owner-name]-[your-space-name].hf.space\",\n",
+ " api_key=\"[your-api-key]\",\n",
+ " num_generations=2\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run the pipeline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below, you can see the full pipeline definition:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with Pipeline(name=\"generate-dataset\") as pipeline:\n",
+ "\n",
+ " load_dataset = LoadDataFromHub(repo_id=\"argilla/10Kprompts-mini\")\n",
+ "\n",
+ " generate_responses = [\n",
+ " TextGeneration(\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
+ " tokenizer_id=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.7},\n",
+ " )\n",
+ " ),\n",
+ " TextGeneration(\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
+ " tokenizer_id=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.7},\n",
+ " )\n",
+ " ),\n",
+ " ]\n",
+ "\n",
+ " group_responses = GroupColumns(\n",
+ " columns=[\"generation\", \"model_name\"],\n",
+ " output_columns=[\"generations\", \"model_names\"],\n",
+ " )\n",
+ "\n",
+ " evaluate_responses = UltraFeedback(\n",
+ " aspect=\"overall-rating\",\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"meta-llama/Meta-Llama-3-70B-Instruct\",\n",
+ " tokenizer_id=\"meta-llama/Meta-Llama-3-70B-Instruct\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.7},\n",
+ " )\n",
+ " )\n",
+ "\n",
+ " format_dpo = FormatTextGenerationDPO()\n",
+ "\n",
+ " to_argilla = PreferenceToArgilla(\n",
+ " dataset_name=\"preference-dataset\",\n",
+ " dataset_workspace=\"argilla\",\n",
+ " api_url=\"https://[your-owner-name]-[your-space-name].hf.space\",\n",
+ " api_key=\"[your-api-key]\",\n",
+ " num_generations=2\n",
+ " )\n",
+ "\n",
+ " for task in generate_responses:\n",
+ " load_dataset.connect(task)\n",
+ " task.connect(group_responses)\n",
+ " group_responses.connect(evaluate_responses)\n",
+ " evaluate_responses.connect(format_dpo, to_argilla)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's now run the pipeline and generate the preference dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "distiset = pipeline.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's check the preference dataset! If you have loaded the data to Argilla, you can [start annotating in the Argilla UI](https://docs.argilla.io/latest/how_to_guides/annotate/)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can push the dataset to the Hub for sharing with the community and [embed it to explore the data](https://huggingface.co/docs/hub/datasets-viewer-embed)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "distiset.push_to_hub(\"[your-owner-name]/example-preference-dataset\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Conclusions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial, we showcased the detailed steps to build a pipeline for generating a preference dataset using distilabel. You can customize this pipeline for your own use cases and share your datasets with the community through the Hugging Face Hub, or use them to train a model for DPO or ORPO.\n",
+ "\n",
+ "We used a dataset containing prompts to generate responses using two different models through the serverless Hugging Face Inference API. Next, we evaluated the responses using a third model, following the UltraFeedback standards. Finally, we converted the data to a preference dataset and used Argilla for further curation."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "distilabel-tutorials",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/finepersonas_social_ai.py b/examples/finepersonas_social_ai.py
new file mode 100644
index 0000000000..8c4f9afc73
--- /dev/null
+++ b/examples/finepersonas_social_ai.py
@@ -0,0 +1,124 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Literal
+
+from datasets import load_dataset
+
+from distilabel.llms import InferenceEndpointsLLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import FormatTextGenerationSFT, LoadDataFromDicts
+from distilabel.steps.tasks import TextGeneration
+
+
+class SocialAI(TextGeneration):
+ follower_type: Literal["supporter", "troll", "alarmist"] = "supporter"
+ system_prompt: str = (
+ "You are an AI assistant expert at simulating user interactions. "
+ "You must answer as if you were a '{follower_type}', be concise answer with no more than 200 characters, nothing else."
+ "Here are some traits to use for your personality:\n\n"
+ "{traits}"
+ )
+ template: str = "You are the folowing persona:\n\n{{ persona }}\n\nWhat would you say to the following?\n\n {{ post }}"
+ columns: str | list[str] = ["persona", "post"]
+
+ _follower_traits: dict[str, str] = {
+ "supporter": (
+ "- Encouraging and positive\n"
+ "- Tends to prioritize enjoyment and relaxation\n"
+ "- Focuses on the present moment and short-term pleasure\n"
+ "- Often uses humor and playful language\n"
+ "- Wants to help others feel good and have fun\n"
+ ),
+ "troll": (
+ "- Provocative and confrontational\n"
+ "- Enjoys stirring up controversy and conflict\n"
+ "- Often uses sarcasm, irony, and mocking language\n"
+ "- Tends to belittle or dismiss others' opinions and feelings\n"
+ "- Seeks to get a rise out of others and create drama\n"
+ ),
+ "alarmist": (
+ "- Anxious and warning-oriented\n"
+ "- Focuses on potential risks and negative consequences\n"
+ "- Often uses dramatic or sensational language\n"
+ "- Tends to be serious and stern in tone\n"
+ "- Seeks to alert others to potential dangers and protect them from harm (even if it's excessive or unwarranted)\n"
+ ),
+ }
+
+ def load(self) -> None:
+ super().load()
+ self.system_prompt = self.system_prompt.format(
+ follower_type=self.follower_type,
+ traits=self._follower_traits[self.follower_type],
+ )
+
+
+posts = [
+ {
+ "post": "Hmm, ok now I'm torn: should I go for healthy chicken tacos or unhealthy beef tacos for late night cravings?"
+ },
+ {
+ "post": "I need to develop a training course for my company on communication skills. Need to decide how deliver it remotely."
+ },
+ {
+ "post": "I'm always 10 minutes late to meetups but no one's complained. Could this be annoying to them?"
+ },
+]
+
+personas = (
+ load_dataset("argilla/FinePersonas-v0.1-clustering-100k", split="train")
+ .shuffle()
+ .select(range(3))
+ .select_columns("persona")
+ .to_list()
+)
+
+data = []
+for post in posts:
+ for persona in personas:
+ data.append({"post": post["post"], "persona": persona["persona"]})
+
+
+with Pipeline(name="Social AI Personas") as pipeline:
+ loader = LoadDataFromDicts(data=data, batch_size=1)
+
+ llm = InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 256,
+ },
+ )
+
+ for follower_type in ["supporter", "troll", "alarmist"]:
+ follower = SocialAI(
+ llm=llm,
+ follower_type=follower_type,
+ name=f"{follower_type}_user",
+ output_mappings={"generation": f"interaction_{follower_type}"},
+ )
+ format_sft = FormatTextGenerationSFT(
+ name=f"format_sft_{follower_type}",
+ input_mappings={
+ "instruction": "post",
+ "generation": f"interaction_{follower_type}",
+ },
+ )
+ loader >> follower >> format_sft
+
+
+if __name__ == "__main__":
+ distiset = pipeline.run(use_cache=False)
+ distiset.push_to_hub("plaguss/FinePersonas-SocialAI-test", include_script=True)
diff --git a/examples/lib_apigen.py b/examples/lib_apigen.py
new file mode 100644
index 0000000000..d49f414e68
--- /dev/null
+++ b/examples/lib_apigen.py
@@ -0,0 +1,146 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+
+def final_velocity(initial_velocity: float, acceleration: float, time: float) -> int:
+ """Calculates the final velocity of an object given its initial velocity, acceleration, and time.
+
+ Args:
+ initial_velocity: The initial velocity of the object.
+ acceleration: The acceleration of the object.
+ time: The time elapsed.
+
+ Returns:
+ The final velocity
+ """
+ # Tool:
+ # {"name": "final_velocity", "description": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.", "parameters": {"initial_velocity": {"description": "The initial velocity of the object.", "type": "float"}, "acceleration": {"description": "The acceleration of the object.", "type": "float"}, "time": {"description": "The time elapsed.", "type": "float"}}}
+ # Answer:
+ # {"name": "final_velocity", "arguments": {"initial_velocity": 5, "acceleration": 1.5, "time": 40}}
+ return initial_velocity + acceleration * time
+
+
+def permutation_count(n: int, k: int) -> int:
+ """Calculates the number of permutations of k elements from a set of n elements.
+
+ Args:
+ n: The total number of elements in the set.
+ k: The number of elements to choose for the permutation.
+
+ Returns:
+ The number of permutations.
+ """
+ # Tool:
+ # {"name": "permutation_count", "description": "Calculates the number of permutations of k elements from a set of n elements.", "parameters": {"n": {"description": "The total number of elements in the set.", "type": "int"}, "k": {"description": "The number of elements to choose for the permutation.", "type": "int"}}}
+ # Answer:
+ # {"name": "permutation_count", "arguments": {"n": 10, "k": 3}}
+ import math
+
+ return math.factorial(n) / math.factorial(n - k)
+
+
+def getdivision(dividend: int, divisor: int) -> float:
+ """Divides two numbers by making an API call to a division service.
+
+ Args:
+ dividend: The dividend in the division operation.
+ divisor: The divisor in the division operation.
+
+ Returns:
+ Division of the 2 numbers.
+ """
+ # Tool:
+ # {"name": "getdivision", "description": "Divides two numbers by making an API call to a division service.", "parameters": {"divisor": {"description": "The divisor in the division operation.", "type": "int", "default": ""}, "dividend": {"description": "The dividend in the division operation.", "type": "int", "default": ""}}}
+ # Answer:
+ # {"name": "getdivision", "arguments": {"divisor": 25, "dividend": 100}}
+ return dividend / divisor
+
+
+def binary_addition(a: str, b: str) -> str:
+ """Adds two binary numbers and returns the result as a binary string.
+
+ Args:
+ a: The first binary number.
+ b: The second binary number.
+
+ Raises:
+ ValueError: On invalid binary number.
+
+ Returns:
+ Binary string of the sum of the two numbers.
+ """
+ # Tool:
+ # {"name": "binary_addition", "description": "Adds two binary numbers and returns the result as a binary string.", "parameters": {"a": {"description": "The first binary number.", "type": "str"}, "b": {"description": "The second binary number.", "type": "str"}}}
+ # Answer:
+ # {"name": "binary_addition", "arguments": {"a": "1010", "b": "1101"}}
+ if not set(a).issubset("01") or not set(b).issubset("01"):
+ raise ValueError("Invalid binary number")
+
+ return bin(int(a, 2) + int(b, 2))[2:]
+
+
+def _make_request(url: str, params: Optional[Dict[str, Any]] = None):
+ import requests
+
+ req = requests.get(url, params=params)
+ return req.json()
+
+
+def swapi_planet_resource(id: str) -> Dict[str, Any]:
+ """get a specific planets resource
+
+ Args:
+ id: identifier of the planet
+
+ Returns:
+ Information about the planet.
+ """
+ # url = "https://swapi.dev/api/planets/1"
+ return _make_request(r"https://swapi.dev/api/planets/", params={"id": id})
+
+
+def disney_character(name: str) -> Dict[str, Any]:
+ """Find a specific character using this endpoint
+
+ Args:
+ name: Name of the character to look for.
+
+ Returns:
+ Infrmation about the character.
+ """
+ # Example:
+ # url = "https://api.disneyapi.dev/character"
+ # params = {"name": "mulan"}
+ return _make_request(r"https://api.disneyapi.dev/character", params={"name": name})
+
+
+def get_lib():
+ return {
+ "swapi_planet_resource": swapi_planet_resource,
+ "disney_character": disney_character,
+ "final_velocity": final_velocity,
+ "permutation_count": permutation_count,
+ "getdivision": getdivision,
+ "binary_addition": binary_addition,
+ }
+
+
+def get_tools() -> Dict[str, Dict[str, Any]]:
+ """Returns the tool representation of the functions in the library."""
+ # TODO: Improve the `get_json_schema`, it fails on a lot of examples.
+ from transformers.utils import get_json_schema
+
+ return {name: get_json_schema(func) for name, func in get_lib().items()}
diff --git a/examples/pipeline_apigen.py b/examples/pipeline_apigen.py
new file mode 100644
index 0000000000..e63e16e39e
--- /dev/null
+++ b/examples/pipeline_apigen.py
@@ -0,0 +1,116 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+from datasets import load_dataset
+
+from distilabel.llms import InferenceEndpointsLLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import CombineOutputs, DataSampler, LoadDataFromDicts
+from distilabel.steps.tasks import (
+ APIGenExecutionChecker,
+ APIGenGenerator,
+ APIGenSemanticChecker,
+)
+from distilabel.steps.tasks.apigen.utils import PrepareExamples, load_module_from_path
+
+libpath = Path(__file__).parent / "lib_apigen.py"
+
+data = [
+ {
+ "func_name": "final_velocity",
+ "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+ },
+ {
+ "func_name": "permutation_count",
+ "func_desc": "Calculates the number of permutations of k elements from a set of n elements.",
+ },
+ {
+ "func_name": "getdivision",
+ "func_desc": "Divides two numbers by making an API call to a division service.",
+ },
+ {
+ "func_name": "binary_addition",
+ "func_desc": "Adds two binary numbers and returns the result as a binary string.",
+ },
+ {
+ "func_name": "swapi_planet_resource",
+ "func_desc": "get a specific planets resource",
+ },
+ {
+ "func_name": "disney_character",
+ "func_desc": "Find a specific character using this endpoint",
+ },
+]
+
+libpath_module = load_module_from_path(libpath)
+tools = libpath_module.get_tools() # call get_tools()
+
+# TODO: Add in the tools between 0 and 2 extra tools to make the task more challenging.
+for row in data:
+ # The tools should have a mix where both the correct and irrelevant tools are present.
+ row.update({"tools": [tools[row["func_name"]]]})
+
+
+ds_og = (
+ load_dataset("Salesforce/xlam-function-calling-60k", split="train")
+ .shuffle(seed=42)
+ .select(range(500))
+ .to_list()
+)
+
+
+with Pipeline(name="APIGenPipeline") as pipeline:
+ loader_seeds = LoadDataFromDicts(data=data)
+ sampler = DataSampler(
+ data=ds_og,
+ size=2,
+ samples=len(data),
+ batch_size=8,
+ )
+
+ prep_examples = PrepareExamples()
+
+ model_id = "meta-llama/Meta-Llama-3.1-70B-Instruct"
+ llm = InferenceEndpointsLLM(
+ model_id=model_id,
+ tokenizer_id=model_id,
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 2048,
+ },
+ )
+ apigen = APIGenGenerator(
+ llm=llm,
+ use_default_structured_output=True,
+ )
+ combine_steps = CombineOutputs()
+
+ execution_checker = APIGenExecutionChecker(libpath=str(libpath))
+ semantic_checker = APIGenSemanticChecker(llm=llm)
+
+ sampler >> prep_examples
+ (
+ [loader_seeds, prep_examples]
+ >> combine_steps
+ >> apigen
+ >> execution_checker
+ >> semantic_checker
+ )
+
+
+if __name__ == "__main__":
+ distiset = pipeline.run()
+ print(distiset["default"]["train"][0])
diff --git a/mkdocs.yml b/mkdocs.yml
index 5c24a8839c..69aaeed275 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -7,6 +7,7 @@ site_description: Distilabel is an AI Feedback (AIF) framework for building data
# Repository
repo_name: argilla-io/distilabel
repo_url: https://github.com/argilla-io/distilabel
+edit_uri: edit/main/docs/
extra:
version:
@@ -21,8 +22,8 @@ extra:
- icon: fontawesome/brands/discord
link: http://hf.co/join/discord
analytics:
- provider: google
- property: G-PPKL7LMWCE
+ provider: plausible
+ domain: distilabel.argilla.io
feedback:
title: Was this page helpful?
ratings:
@@ -135,61 +136,86 @@ plugins:
- mkdocstrings:
handlers:
python:
- selection:
- inherited_members: true # Allow looking up inherited methods
+ setup_commands:
+ - import sys; sys.path.insert(0, 'src') # API references are built from source
options:
- show_protected_members: true
- show_private_members: true
- rendering:
- show_root_heading: true # actually display anything at all...
- # show_root_full_path: true # display "diffrax.asdf" not just "asdf"
- show_if_no_docstring: true
- show_signature_annotations: true
- show_source: false # don't include source code
+ show_inheritance_diagram: false
+ show_source: true # include source code
+ # Headings
+ heading_level: 3
+ show_root_heading: true # show the python path of the class
+ show_root_toc_entry: true # show the toc entry for the root class
+ show_root_full_path: false # display "diffrax.asdf" not just "asdf"
+ show_object_full_path: false # display "diffrax.asdf" not just "asdf"
+ show_symbol_type_heading: true
+ show_symbol_type_toc: true
+ # Members
+ inherited_members: false # allow looking up inherited methods
members_order: source # order methods according to their order of definition in the source code, not alphabetical order
- heading_level: 4
+ show_labels: true
+ # Docstring
+ docstring_style: google # more info: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
+ show_if_no_docstring: false
+ # Signature
+ separate_signature: false
+ show_signature_annotations: false
- social
+ - mknotebooks
+ - material-plausible
+ - glightbox
- distilabel/components-gallery:
add_after_page: How-to guides
nav:
- Distilabel: "index.md"
- Getting started:
- - Installation: "sections/getting_started/installation.md"
- Quickstart: "sections/getting_started/quickstart.md"
+ - Installation: "sections/getting_started/installation.md"
- FAQ: "sections/getting_started/faq.md"
- How-to guides:
- "sections/how_to_guides/index.md"
- Basic:
- - Define Steps for your Pipeline:
+ - Steps for processing data:
- "sections/how_to_guides/basic/step/index.md"
- GeneratorStep: "sections/how_to_guides/basic/step/generator_step.md"
- GlobalStep: "sections/how_to_guides/basic/step/global_step.md"
- - Define Tasks that rely on LLMs:
+ - Tasks for generating and judging with LLMs:
- "sections/how_to_guides/basic/task/index.md"
- GeneratorTask: "sections/how_to_guides/basic/task/generator_task.md"
- - Define LLMs as local or remote models: "sections/how_to_guides/basic/llm/index.md"
+ - Executing Tasks with LLMs: "sections/how_to_guides/basic/llm/index.md"
- Execute Steps and Tasks in a Pipeline: "sections/how_to_guides/basic/pipeline/index.md"
- Advanced:
- - Using the Distiset dataset object: "sections/how_to_guides/advanced/distiset.md"
- - Cache and recover pipeline executions: "sections/how_to_guides/advanced/caching.md"
- - Export data to Argilla: "sections/how_to_guides/advanced/argilla.md"
+ - The Distiset dataset object: "sections/how_to_guides/advanced/distiset.md"
+ - Pipeline cache: "sections/how_to_guides/advanced/caching.md"
+ - Exporting data to Argilla: "sections/how_to_guides/advanced/argilla.md"
- Structured data generation: "sections/how_to_guides/advanced/structured_generation.md"
- - Specify requirements for pipelines and steps: "sections/how_to_guides/advanced/pipeline_requirements.md"
+ - Offline Batch Generation: "sections/how_to_guides/advanced/offline_batch_generation.md"
+ - Specifying requirements for pipelines and steps: "sections/how_to_guides/advanced/pipeline_requirements.md"
- Using CLI to explore and re-run existing Pipelines: "sections/how_to_guides/advanced/cli/index.md"
- Using a file system to pass data of batches between steps: "sections/how_to_guides/advanced/fs_to_pass_data.md"
- Assigning resources to a step: "sections/how_to_guides/advanced/assigning_resources_to_step.md"
+ - Saving step generated artifacts: "sections/how_to_guides/advanced/saving_step_generated_artifacts.md"
- Serving an LLM for sharing it between several tasks: "sections/how_to_guides/advanced/serving_an_llm_for_reuse.md"
- Scaling and distributing a pipeline with Ray: "sections/how_to_guides/advanced/scaling_with_ray.md"
- - Pipeline Samples:
- - Examples: "sections/pipeline_samples/examples/index.md"
+ - Tutorials:
+ - "sections/pipeline_samples/index.md"
+ - Tutorials:
+ - Generate a preference dataset: "sections/pipeline_samples/tutorials/generate_preference_dataset.ipynb"
+ - Clean an existing preference dataset: "sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb"
+ - Synthetic data generation for fine-tuning custom retrieval and reranking models: "sections/pipeline_samples/tutorials/GenerateSentencePair.ipynb"
- Papers:
- - "sections/pipeline_samples/papers/index.md"
+ - DeepSeek Prover: "sections/pipeline_samples/papers/deepseek_prover.md"
- DEITA: "sections/pipeline_samples/papers/deita.md"
- Instruction Backtranslation: "sections/pipeline_samples/papers/instruction_backtranslation.md"
- Prometheus 2: "sections/pipeline_samples/papers/prometheus.md"
- UltraFeedback: "sections/pipeline_samples/papers/ultrafeedback.md"
- - DeepSeek Prover: "sections/pipeline_samples/papers/deepseek_prover.md"
+ - APIGen: "sections/pipeline_samples/papers/apigen.md"
+ - CLAIR: "sections/pipeline_samples/papers/clair.md"
+ - Examples:
+ - Benchmarking with distilabel: "sections/pipeline_samples/examples/benchmarking_with_distilabel.md"
+ - Structured generation with outlines: "sections/pipeline_samples/examples/llama_cpp_with_outlines.md"
+ - Structured generation with instructor: "sections/pipeline_samples/examples/mistralai_with_instructor.md"
+ - Create a social network with FinePersonas: "sections/pipeline_samples/examples/fine_personas_social_network.md"
- API Reference:
- Step:
- "api/step/index.md"
@@ -202,38 +228,32 @@ nav:
- Hugging Face: "api/step_gallery/hugging_face.md"
- Columns: "api/step_gallery/columns.md"
- Extra: "api/step_gallery/extra.md"
+ - Typing: "api/step/typing.md"
- Task:
- "api/task/index.md"
- GeneratorTask: "api/task/generator_task.md"
- - Task Gallery: "api/task_gallery/index.md"
+ - Task Gallery: "api/task/task_gallery.md"
- Typing: "api/task/typing.md"
- LLM:
- "api/llm/index.md"
- - LLM Gallery:
- - Anthropic: "api/llm/anthropic.md"
- - Anyscale: "api/llm/anyscale.md"
- - Azure (via OpenAI): "api/llm/azure.md"
- - Cohere: "api/llm/cohere.md"
- - Groq: "api/llm/groq.md"
- - Hugging Face: "api/llm/huggingface.md"
- - LiteLLM: "api/llm/litellm.md"
- - llama.cpp: "api/llm/llamacpp.md"
- - Mistral: "api/llm/mistral.md"
- - Ollama: "api/llm/ollama.md"
- - OpenAI: "api/llm/openai.md"
- - Together AI: "api/llm/together.md"
- - Google Vertex AI: "api/llm/vertexai.md"
- - vLLM: "api/llm/vllm.md"
+ - LLM Gallery: "api/llm/llm_gallery.md"
+ - Embedding:
+ - "api/embedding/index.md"
+ - Embedding Gallery: "api/embedding/embedding_gallery.md"
- Pipeline:
- "api/pipeline/index.md"
- Routing Batch Function: "api/pipeline/routing_batch_function.md"
- Typing: "api/pipeline/typing.md"
- - Utils: "api/pipeline/utils.md"
+ - Step Wrapper: "api/pipeline/step_wrapper.md"
- Mixins:
- RuntimeParametersMixin: "api/mixins/runtime_parameters.md"
- RequirementsMixin: "api/mixins/requirements.md"
+ - Exceptions: "api/exceptions.md"
+ - Errors: "api/errors.md"
- Distiset: "api/distiset.md"
- CLI: "api/cli.md"
- Community:
- sections/community/index.md
+ - How to contribute?: sections/community/contributor.md
+ - Developer Documentation: sections/community/developer_documentation.md
- Issue dashboard: sections/community/popular_issues.md
diff --git a/pyproject.toml b/pyproject.toml
index 393941fe2c..bf9550c9f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,11 +55,14 @@ docs = [
"mkdocs-literate-nav >= 0.6.1",
"mkdocs-section-index >= 0.3.8",
"mkdocs-gen-files >= 0.5.0",
+ "mkdocs-glightbox >= 0.4.0",
+ "material-plausible-plugin>=0.2.0",
"mike >= 2.0.0",
"Pillow >= 9.5.0",
"CairoSVG >= 2.7.1",
"mknotebooks >= 0.8.0",
"pandas >= 2.0",
+ "tabulate>=0.9.0",
]
tests = [
"pytest >= 7.4.0",
@@ -79,10 +82,10 @@ hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"]
instructor = ["instructor >= 1.2.3"]
litellm = ["litellm >= 1.30.0"]
llama-cpp = ["llama-cpp-python >= 0.2.0"]
-mistralai = ["mistralai >= 0.1.0"]
+mistralai = ["mistralai >= 1.0.0"]
ollama = ["ollama >= 0.1.7"]
openai = ["openai >= 1.0.0"]
-outlines = ["outlines >= 0.0.40"]
+outlines = ["outlines >= 0.0.40", "numba >= 0.54.0"]
ray = ["ray[default] >= 2.31.0"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
vllm = [
@@ -93,7 +96,15 @@ vllm = [
]
sentence-transformers = ["sentence-transformers >= 3.0.0"]
faiss-cpu = ["faiss-cpu >= 1.8.0"]
-faiss-gpu = ["faiss-cpu >= 1.7.2"]
+faiss-gpu = ["faiss-gpu >= 1.7.2"]
+text-clustering = [
+ "umap-learn >= 0.5.6",
+ "scikit-learn >= 1.4.1",
+ "matplotlib >= 3.8.3", # For the figure (even though it's optional)
+]
+
+# minhash
+minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"]
[project.urls]
Documentation = "https://distilabel.argilla.io/"
diff --git a/scripts/install_cpu_vllm.sh b/scripts/install_cpu_vllm.sh
index 199413a1e2..bdaa7ad74e 100755
--- a/scripts/install_cpu_vllm.sh
+++ b/scripts/install_cpu_vllm.sh
@@ -4,7 +4,7 @@ set -e
echo "Updating system and installing build dependencies..."
sudo apt-get update -y
-sudo apt-get install -y gcc-12 g++-12 libnuma-dev cmake
+sudo apt-get install -y gcc-12 g++-12 libnuma-dev cmake libdnnl-dev
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
echo "Python version:"
@@ -15,7 +15,7 @@ which python
echo "Installing Python build dependencies..."
python -m pip install --upgrade pip
-python -m pip install wheel packaging ninja "setuptools>=49.4.0" numpy
+python -m pip install wheel packaging ninja "setuptools>=49.4.0" numpy setuptools-scm
echo "Cloning 'vllm-project/vllm' GitHub repository..."
git clone https://github.com/vllm-project/vllm.git
diff --git a/scripts/install_dependencies.sh b/scripts/install_dependencies.sh
index 328c5fe029..0b2277f0fb 100755
--- a/scripts/install_dependencies.sh
+++ b/scripts/install_dependencies.sh
@@ -6,13 +6,12 @@ python_version=$(python -c "import sys; print(sys.version_info[:2])")
python -m pip install uv
-uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu]"
+uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash,text-clustering]"
if [ "${python_version}" != "(3, 12)" ]; then
- uv pip install --system -e .[ray]
+ uv pip install --system -e .[ray]
fi
./scripts/install_cpu_vllm.sh
-uv pip install --system git+https://github.com/argilla-io/LLM-Blender.git
uv pip install --system -e ".[dev,tests]"
diff --git a/src/distilabel/__init__.py b/src/distilabel/__init__.py
index 8ee519f077..bafff914bf 100644
--- a/src/distilabel/__init__.py
+++ b/src/distilabel/__init__.py
@@ -14,6 +14,6 @@
from rich import traceback as rich_traceback
-__version__ = "1.3.2"
+__version__ = "1.4.0"
rich_traceback.install(show_locals=True)
diff --git a/src/distilabel/cli/pipeline/utils.py b/src/distilabel/cli/pipeline/utils.py
index 2f8f9170a4..42204c1605 100644
--- a/src/distilabel/cli/pipeline/utils.py
+++ b/src/distilabel/cli/pipeline/utils.py
@@ -24,6 +24,7 @@
from pydantic.type_adapter import TypeAdapter
from distilabel.constants import ROUTING_BATCH_FUNCTION_ATTR_NAME, STEP_ATTR_NAME
+from distilabel.errors import DistilabelUserError
from distilabel.pipeline.local import Pipeline
if TYPE_CHECKING:
@@ -106,8 +107,9 @@ def get_config_from_url(url: str) -> Dict[str, Any]:
ValueError: If the file format is not supported.
"""
if not url.endswith((".json", ".yaml", ".yml")):
- raise ValueError(
- f"Unsupported file format for '{url}'. Only JSON and YAML are supported"
+ raise DistilabelUserError(
+ f"Unsupported file format for '{url}'. Only JSON and YAML are supported",
+ page="sections/how_to_guides/basic/pipeline/?h=seriali#serializing-the-pipeline",
)
response = _download_remote_file(url)
@@ -134,8 +136,9 @@ def get_pipeline_from_url(url: str, pipeline_name: str = "pipeline") -> "BasePip
ValueError: If the file format is not supported.
"""
if not url.endswith(".py"):
- raise ValueError(
- f"Unsupported file format for '{url}'. It must be a python file."
+ raise DistilabelUserError(
+ f"Unsupported file format for '{url}'. It must be a python file.",
+ page="sections/how_to_guides/advanced/cli/#distilabel-pipeline-run",
)
response = _download_remote_file(url)
@@ -179,8 +182,9 @@ def get_pipeline(
elif config_or_script.endswith(".py"):
script = config_or_script
else:
- raise ValueError(
- "The file must be a valid config file or python script with a pipeline."
+ raise DistilabelUserError(
+ "The file must be a valid config file or python script with a pipeline.",
+ page="sections/how_to_guides/advanced/cli/#distilabel-pipeline-run",
)
if valid_http_url(config_or_script):
diff --git a/src/distilabel/constants.py b/src/distilabel/constants.py
index a1400bcd03..44554f8423 100644
--- a/src/distilabel/constants.py
+++ b/src/distilabel/constants.py
@@ -12,12 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from pathlib import Path
from typing import Final
# Steps related constants
DISTILABEL_METADATA_KEY: Final[str] = "distilabel_metadata"
-# Pipeline related constants
+# Cache
+BASE_CACHE_DIR = Path.home() / ".cache" / "distilabel"
+PIPELINES_CACHE_DIR = BASE_CACHE_DIR / "pipelines"
+
+# Pipeline dag related constants
STEP_ATTR_NAME: Final[str] = "step"
INPUT_QUEUE_ATTR_NAME: Final[str] = "input_queue"
RECEIVES_ROUTED_BATCHES_ATTR_NAME: Final[str] = "receives_routed_batches"
@@ -25,13 +30,41 @@
CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step"
LAST_BATCH_SENT_FLAG: Final[str] = "last_batch_sent"
+# Pipeline execution related constants
+PIPELINE_NAME_ENV_NAME = "DISTILABEL_PIPELINE_NAME"
+PIPELINE_CACHE_ID_ENV_NAME = "DISTILABEL_PIPELINE_CACHE_ID"
+SIGINT_HANDLER_CALLED_ENV_NAME = "sigint_handler_called"
+
+# Data paths constants
+STEPS_OUTPUTS_PATH = "steps_outputs"
+STEPS_ARTIFACTS_PATH = "steps_artifacts"
+
+# Distiset related constants
+DISTISET_CONFIG_FOLDER: Final[str] = "distiset_configs"
+DISTISET_ARTIFACTS_FOLDER: Final[str] = "artifacts"
+PIPELINE_CONFIG_FILENAME: Final[str] = "pipeline.yaml"
+PIPELINE_LOG_FILENAME: Final[str] = "pipeline.log"
+
+# Docs page for the custom errors
+DISTILABEL_DOCS_URL: Final[str] = "https://distilabel.argilla.io/latest/"
+
__all__ = [
+ "DISTILABEL_METADATA_KEY",
+ "BASE_CACHE_DIR",
+ "PIPELINES_CACHE_DIR",
"STEP_ATTR_NAME",
"INPUT_QUEUE_ATTR_NAME",
"RECEIVES_ROUTED_BATCHES_ATTR_NAME",
"ROUTING_BATCH_FUNCTION_ATTR_NAME",
"CONVERGENCE_STEP_ATTR_NAME",
"LAST_BATCH_SENT_FLAG",
- "DISTILABEL_METADATA_KEY",
+ "SIGINT_HANDLER_CALLED_ENV_NAME",
+ "STEPS_OUTPUTS_PATH",
+ "STEPS_ARTIFACTS_PATH",
+ "DISTISET_CONFIG_FOLDER",
+ "DISTISET_ARTIFACTS_FOLDER",
+ "PIPELINE_CONFIG_FILENAME",
+ "PIPELINE_LOG_FILENAME",
+ "DISTILABEL_DOCS_URL",
]
diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py
index 92eedecfca..8e52c667d3 100644
--- a/src/distilabel/distiset.py
+++ b/src/distilabel/distiset.py
@@ -12,24 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import logging
import os.path as posixpath
import re
import sys
+from collections import defaultdict
from os import PathLike
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
import fsspec
import yaml
from datasets import Dataset, load_dataset, load_from_disk
from datasets.filesystems import is_remote_filesystem
-from huggingface_hub import DatasetCardData, HfApi, upload_file
+from huggingface_hub import DatasetCardData, HfApi, upload_file, upload_folder
from huggingface_hub.file_download import hf_hub_download
from pyarrow.lib import ArrowInvalid
from typing_extensions import Self
-from distilabel.constants import STEP_ATTR_NAME
+from distilabel.constants import (
+ DISTISET_ARTIFACTS_FOLDER,
+ DISTISET_CONFIG_FOLDER,
+ PIPELINE_CONFIG_FILENAME,
+ PIPELINE_LOG_FILENAME,
+ STEP_ATTR_NAME,
+ STEPS_ARTIFACTS_PATH,
+ STEPS_OUTPUTS_PATH,
+)
from distilabel.utils.card.dataset_card import (
DistilabelDatasetCard,
size_categories_parser,
@@ -42,11 +52,6 @@
from distilabel.pipeline._dag import DAG
-DISTISET_CONFIG_FOLDER: Final[str] = "distiset_configs"
-PIPELINE_CONFIG_FILENAME: Final[str] = "pipeline.yaml"
-PIPELINE_LOG_FILENAME: Final[str] = "pipeline.log"
-
-
class Distiset(dict):
"""Convenient wrapper around `datasets.Dataset` to push to the Hugging Face Hub.
@@ -54,12 +59,18 @@ class Distiset(dict):
`DAG` and the values are `datasets.Dataset`.
Attributes:
- pipeline_path: Optional path to the pipeline.yaml file that generated the dataset.
- log_filename_path: Optional path to the pipeline.log file that generated was written by the
- pipeline.
+ _pipeline_path: Optional path to the `pipeline.yaml` file that generated the dataset.
+ Defaults to `None`.
+ _artifacts_path: Optional path to the directory containing the generated artifacts
+ by the pipeline steps. Defaults to `None`.
+ _log_filename_path: Optional path to the `pipeline.log` file that generated was written
+ by the pipeline. Defaults to `None`.
+ _citations: Optional list containing citations that will be included in the dataset
+ card. Defaults to `None`.
"""
_pipeline_path: Optional[Path] = None
+ _artifacts_path: Optional[Path] = None
_log_filename_path: Optional[Path] = None
_citations: Optional[List[str]] = None
@@ -121,6 +132,16 @@ def push_to_hub(
**kwargs,
)
+ if self.artifacts_path:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=self.artifacts_path,
+ path_in_repo="artifacts",
+ token=token,
+ repo_type="dataset",
+ commit_message="Include pipeline artifacts",
+ )
+
if include_script and script_path.exists():
upload_file(
path_or_fileobj=script_path,
@@ -128,7 +149,7 @@ def push_to_hub(
repo_id=repo_id,
repo_type="dataset",
token=token,
- commit_message="Include pipeline script.",
+ commit_message="Include pipeline script",
)
if generate_card:
@@ -185,11 +206,38 @@ def _get_card(
sample_records=sample_records,
include_script=include_script,
filename_py=filename_py,
+ artifacts=self._get_artifacts_metadata(),
references=self.citations,
)
return card
+ def _get_artifacts_metadata(self) -> Dict[str, List[Dict[str, Any]]]:
+ """Gets a dictionary with the metadata of the artifacts generated by the pipeline steps.
+
+ Returns:
+ A dictionary in which the key is the name of the step and the value is a list
+ of dictionaries, each of them containing the name and metadata of the step artifact.
+ """
+ if not self.artifacts_path:
+ return {}
+
+ def iterdir_ignore_hidden(path: Path) -> Generator[Path, None, None]:
+ return (f for f in Path(path).iterdir() if not f.name.startswith("."))
+
+ artifacts_metadata = defaultdict(list)
+ for step_artifacts_dir in iterdir_ignore_hidden(self.artifacts_path):
+ step_name = step_artifacts_dir.stem
+ for artifact_dir in iterdir_ignore_hidden(step_artifacts_dir):
+ artifact_name = artifact_dir.stem
+ metadata_path = artifact_dir / "metadata.json"
+ metadata = json.loads(metadata_path.read_text())
+ artifacts_metadata[step_name].append(
+ {"name": artifact_name, "metadata": metadata}
+ )
+
+ return dict(artifacts_metadata)
+
def _extract_readme_metadata(
self, repo_id: str, token: Optional[str]
) -> Dict[str, Any]:
@@ -243,6 +291,7 @@ def _generate_card(
repo_type="dataset",
token=token,
)
+
if self.pipeline_path:
# If the pipeline.yaml is available, upload it to the Hugging Face Hub as well.
HfApi().upload_file(
@@ -252,6 +301,7 @@ def _generate_card(
repo_type="dataset",
token=token,
)
+
if self.log_filename_path:
# The same we had with "pipeline.yaml" but with the log file.
HfApi().upload_file(
@@ -329,17 +379,17 @@ def save_to_disk(
Examples:
```python
# Save your distiset in a local folder:
- >>> distiset.save_to_disk(distiset_path="my-distiset")
+ distiset.save_to_disk(distiset_path="my-distiset")
# Save your distiset in a remote storage:
- >>> storage_options = {
- ... "key": os.environ["S3_ACCESS_KEY"],
- ... "secret": os.environ["S3_SECRET_KEY"],
- ... "client_kwargs": {
- ... "endpoint_url": os.environ["S3_ENDPOINT_URL"],
- ... "region_name": os.environ["S3_REGION"],
- ... },
- ... }
- >>> distiset.save_to_disk(distiset_path="my-distiset", storage_options=storage_options)
+ storage_options = {
+ "key": os.environ["S3_ACCESS_KEY"],
+ "secret": os.environ["S3_SECRET_KEY"],
+ "client_kwargs": {
+ "endpoint_url": os.environ["S3_ENDPOINT_URL"],
+ "region_name": os.environ["S3_REGION"],
+ },
+ }
+ distiset.save_to_disk(distiset_path="my-distiset", storage_options=storage_options)
```
"""
distiset_path = str(distiset_path)
@@ -360,6 +410,12 @@ def save_to_disk(
)
fs.makedirs(distiset_config_folder, exist_ok=True)
+ if self.artifacts_path:
+ distiset_artifacts_folder = posixpath.join(
+ distiset_path, DISTISET_ARTIFACTS_FOLDER
+ )
+ fs.copy(str(self.artifacts_path), distiset_artifacts_folder, recursive=True)
+
if save_card:
# NOTE: Currently the card is not the same if we write to disk or push to the HF hub,
# as we aren't generating the README copying/updating the data from the dataset repo.
@@ -415,7 +471,7 @@ def load_from_disk(
original_distiset_path = str(distiset_path)
fs: fsspec.AbstractFileSystem
- fs, _, [distiset_path] = fsspec.get_fs_token_paths(
+ fs, _, [distiset_path] = fsspec.get_fs_token_paths( # type: ignore
original_distiset_path, storage_options=storage_options
)
dest_distiset_path = distiset_path
@@ -425,6 +481,7 @@ def load_from_disk(
), "`distiset_path` must be a `PathLike` object pointing to a folder or a URI of a remote filesystem."
has_config = False
+ has_artifacts = False
distiset = cls()
if is_remote_filesystem(fs):
@@ -432,19 +489,23 @@ def load_from_disk(
if download_dir:
dest_distiset_path = download_dir
else:
- dest_distiset_path = Dataset._build_local_temp_path(src_dataset_path)
- fs.download(src_dataset_path, dest_distiset_path.as_posix(), recursive=True)
+ dest_distiset_path = Dataset._build_local_temp_path(src_dataset_path) # type: ignore
+ fs.download(src_dataset_path, dest_distiset_path.as_posix(), recursive=True) # type: ignore
# Now we should have the distiset locally, so we can read those files
for folder in Path(dest_distiset_path).iterdir():
if folder.stem == DISTISET_CONFIG_FOLDER:
has_config = True
continue
+ elif folder.stem == DISTISET_ARTIFACTS_FOLDER:
+ has_artifacts = True
+ continue
distiset[folder.stem] = load_from_disk(
str(folder),
keep_in_memory=keep_in_memory,
)
- # From the config folder we just need to point to the files. Once downloaded we set the path
+
+ # From the config folder we just need to point to the files. Once downloaded we set the path to point to point to the files. Once downloaded we set the path
# to wherever they are.
if has_config:
distiset_config_folder = posixpath.join(
@@ -463,6 +524,11 @@ def load_from_disk(
if Path(log_filename_path).exists():
distiset.log_filename_path = Path(log_filename_path)
+ if has_artifacts:
+ distiset.artifacts_path = Path(
+ posixpath.join(dest_distiset_path, DISTISET_ARTIFACTS_FOLDER)
+ )
+
return distiset
@property
@@ -474,6 +540,16 @@ def pipeline_path(self) -> Union[Path, None]:
def pipeline_path(self, path: PathLike) -> None:
self._pipeline_path = Path(path)
+ @property
+ def artifacts_path(self) -> Union[Path, None]:
+ """Returns the path to the directory containing the artifacts generated by the steps
+ of the pipeline."""
+ return self._artifacts_path
+
+ @artifacts_path.setter
+ def artifacts_path(self, path: PathLike) -> None:
+ self._artifacts_path = Path(path)
+
@property
def log_filename_path(self) -> Union[Path, None]:
"""Returns the path to the `pipeline.log` file that generated the `Pipeline`."""
@@ -530,20 +606,19 @@ def create_distiset( # noqa: C901
correspond to different configurations of the dataset.
Examples:
-
```python
- >>> from pathlib import Path
- >>> distiset = create_distiset(Path.home() / ".cache/distilabel/pipelines/path-to-pipe-hashname")
+ from pathlib import Path
+ distiset = create_distiset(Path.home() / ".cache/distilabel/pipelines/path-to-pipe-hashname")
```
"""
from distilabel.constants import DISTILABEL_METADATA_KEY
logger = logging.getLogger("distilabel.distiset")
- data_dir = Path(data_dir)
+ steps_outputs_dir = data_dir / STEPS_OUTPUTS_PATH
distiset = Distiset()
- for file in data_dir.iterdir():
+ for file in steps_outputs_dir.iterdir():
if file.is_file():
continue
@@ -569,19 +644,26 @@ def create_distiset( # noqa: C901
if len(distiset.keys()) == 1:
distiset["default"] = distiset.pop(list(distiset.keys())[0])
+ # If there's any artifact set the `artifacts_path` so they can be uploaded
+ steps_artifacts_dir = data_dir / STEPS_ARTIFACTS_PATH
+ if any(steps_artifacts_dir.rglob("*")):
+ distiset.artifacts_path = steps_artifacts_dir
+
+ # Include `pipeline.yaml` if exists
if pipeline_path:
distiset.pipeline_path = pipeline_path
else:
# If the pipeline path is not provided, try to find it in the parent directory
# and assume that's the wanted file.
- pipeline_path = data_dir.parent / "pipeline.yaml"
+ pipeline_path = steps_outputs_dir.parent / "pipeline.yaml"
if pipeline_path.exists():
distiset.pipeline_path = pipeline_path
+ # Include `pipeline.log` if exists
if log_filename_path:
distiset.log_filename_path = log_filename_path
else:
- log_filename_path = data_dir.parent / "pipeline.log"
+ log_filename_path = steps_outputs_dir.parent / "pipeline.log"
if log_filename_path.exists():
distiset.log_filename_path = log_filename_path
@@ -613,8 +695,9 @@ def _grab_citations(dag: "DAG") -> List[str]:
for ref in references.values():
try:
bibtex_refs.append(get_bibtex(ref))
- except ValueError as e:
- print(f"Error: {e}")
+ except ValueError:
+ # No need to inform in this case, it's noise
+ pass
except AttributeError as e:
print(
f"Couldn't obtain the bibtex format for the ref: '{ref}', error: {e}"
diff --git a/src/distilabel/embeddings/__init__.py b/src/distilabel/embeddings/__init__.py
index a7e5e63e2c..190ea70e50 100644
--- a/src/distilabel/embeddings/__init__.py
+++ b/src/distilabel/embeddings/__init__.py
@@ -14,8 +14,10 @@
from distilabel.embeddings.base import Embeddings
from distilabel.embeddings.sentence_transformers import SentenceTransformerEmbeddings
+from distilabel.embeddings.vllm import vLLMEmbeddings
__all__ = [
"Embeddings",
"SentenceTransformerEmbeddings",
+ "vLLMEmbeddings",
]
diff --git a/src/distilabel/embeddings/sentence_transformers.py b/src/distilabel/embeddings/sentence_transformers.py
index 08b3465ad1..85baea3de9 100644
--- a/src/distilabel/embeddings/sentence_transformers.py
+++ b/src/distilabel/embeddings/sentence_transformers.py
@@ -55,7 +55,6 @@ class SentenceTransformerEmbeddings(Embeddings, CudaDevicePlacementMixin):
of 1. Defaults to `None`.
Examples:
-
Generating sentence embeddings:
```python
diff --git a/src/distilabel/embeddings/vllm.py b/src/distilabel/embeddings/vllm.py
new file mode 100644
index 0000000000..cbbadd69af
--- /dev/null
+++ b/src/distilabel/embeddings/vllm.py
@@ -0,0 +1,129 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from pydantic import Field, PrivateAttr
+
+from distilabel.embeddings.base import Embeddings
+from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
+from distilabel.mixins.runtime_parameters import RuntimeParameter
+
+if TYPE_CHECKING:
+ from vllm import LLM as _vLLM
+
+
+class vLLMEmbeddings(Embeddings, CudaDevicePlacementMixin):
+ """`vllm` library implementation for embedding generation.
+
+ Attributes:
+ model: the model Hugging Face Hub repo id or a path to a directory containing the
+ model weights and configuration files.
+ dtype: the data type to use for the model. Defaults to `auto`.
+ trust_remote_code: whether to trust the remote code when loading the model. Defaults
+ to `False`.
+ quantization: the quantization mode to use for the model. Defaults to `None`.
+ revision: the revision of the model to load. Defaults to `None`.
+ enforce_eager: whether to enforce eager execution. Defaults to `True`.
+ seed: the seed to use for the random number generator. Defaults to `0`.
+ extra_kwargs: additional dictionary of keyword arguments that will be passed to the
+ `LLM` class of `vllm` library. Defaults to `{}`.
+ _model: the `vLLM` model instance. This attribute is meant to be used internally
+ and should not be accessed directly. It will be set in the `load` method.
+
+ References:
+ - [Offline inference embeddings](https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_embedding.html)
+
+ Examples:
+ Generating sentence embeddings:
+
+ ```python
+ from distilabel.embeddings import vLLMEmbeddings
+
+ embeddings = vLLMEmbeddings(model="intfloat/e5-mistral-7b-instruct")
+
+ embeddings.load()
+
+ results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
+ # [
+ # [-0.05447685346007347, -0.01623094454407692, ...],
+ # [4.4889533455716446e-05, 0.044016145169734955, ...],
+ # ]
+ ```
+ """
+
+ model: str
+ dtype: str = "auto"
+ trust_remote_code: bool = False
+ quantization: Optional[str] = None
+ revision: Optional[str] = None
+
+ enforce_eager: bool = True
+
+ seed: int = 0
+
+ extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
+ default_factory=dict,
+ description="Additional dictionary of keyword arguments that will be passed to the"
+ " `vLLM` class of `vllm` library. See all the supported arguments at: "
+ "https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py",
+ )
+
+ _model: "_vLLM" = PrivateAttr(None)
+
+ def load(self) -> None:
+ """Loads the `vLLM` model using either the path or the Hugging Face Hub repository id."""
+ super().load()
+
+ CudaDevicePlacementMixin.load(self)
+
+ try:
+ from vllm import LLM as _vLLM
+
+ except ImportError as ie:
+ raise ImportError(
+ "vLLM is not installed. Please install it using `pip install vllm`."
+ ) from ie
+
+ self._model = _vLLM(
+ self.model,
+ dtype=self.dtype,
+ trust_remote_code=self.trust_remote_code,
+ quantization=self.quantization,
+ revision=self.revision,
+ enforce_eager=self.enforce_eager,
+ seed=self.seed,
+ **self.extra_kwargs, # type: ignore
+ )
+
+ def unload(self) -> None:
+ """Unloads the `vLLM` model."""
+ CudaDevicePlacementMixin.unload(self)
+ super().unload()
+
+ @property
+ def model_name(self) -> str:
+ """Returns the name of the model."""
+ return self.model
+
+ def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
+ """Generates embeddings for the provided inputs.
+
+ Args:
+ inputs: a list of texts for which an embedding has to be generated.
+
+ Returns:
+ The generated embeddings.
+ """
+ return [output.outputs.embedding for output in self._model.encode(inputs)]
diff --git a/src/distilabel/envs.py b/src/distilabel/envs.py
new file mode 100644
index 0000000000..500c736e52
--- /dev/null
+++ b/src/distilabel/envs.py
@@ -0,0 +1,52 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Idea from: https://github.com/vllm-project/vllm/blob/main/vllm/envs.py
+
+import os
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+
+from distilabel import constants
+
+if TYPE_CHECKING:
+ DISTILABEL_LOG_LEVEL: str = "INFO"
+ DISTILABEL_PIPELINE_NAME: Optional[str] = None
+ DISTILABEL_PIPELINE_CACHE_ID: Optional[str] = None
+ DISTILABEL_CACHE_DIR: Optional[str] = None
+
+ENVIRONMENT_VARIABLES: Dict[str, Callable[[], Any]] = {
+ # `distilabel` logging level.
+ "DISTILABEL_LOG_LEVEL": lambda: os.getenv("DISTILABEL_LOG_LEVEL", "INFO").upper(),
+ # The name of the `distilabel` pipeline currently running.
+ constants.PIPELINE_NAME_ENV_NAME: lambda: os.getenv(
+ constants.PIPELINE_NAME_ENV_NAME, None
+ ),
+ # The cache ID of the `distilabel` pipeline currently running.
+ constants.PIPELINE_CACHE_ID_ENV_NAME: lambda: os.getenv(
+ constants.PIPELINE_CACHE_ID_ENV_NAME, None
+ ),
+ # The cache ID of the `distilabel` pipeline currently running.
+ "DISTILABEL_CACHE_DIR": lambda: os.getenv("DISTILABEL_CACHE_DIR", None),
+}
+
+
+def __getattr__(name: str) -> Any:
+ # lazy evaluation of environment variables
+ if name in ENVIRONMENT_VARIABLES:
+ return ENVIRONMENT_VARIABLES[name]()
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def __dir__() -> List[str]:
+ return list(ENVIRONMENT_VARIABLES.keys())
diff --git a/src/distilabel/errors.py b/src/distilabel/errors.py
new file mode 100644
index 0000000000..71603aed7f
--- /dev/null
+++ b/src/distilabel/errors.py
@@ -0,0 +1,67 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from distilabel.constants import DISTILABEL_DOCS_URL
+
+# The sitemap can be visited for the full list of pages:
+# SITEMAP_URL: Final[str] = "https://distilabel.argilla.io/latest/sitemap.xml"
+
+
+class DistilabelError:
+ """A mixin class for common functionality shared by all Distilabel-specific errors.
+
+ Attributes:
+ message: A message describing the error.
+ page: An optional error code from PydanticErrorCodes enum.
+
+ Examples:
+ ```python
+ raise DistilabelUserError("This is an error message.")
+ This is an error message.
+
+ raise DistilabelUserError("This is an error message.", page="sections/getting_started/faq/")
+ This is an error message.
+ For further information visit 'https://distilabel.argilla.io/latest/sections/getting_started/faq/'
+ ```
+ """
+
+ def __init__(self, message: str, *, page: Optional[str] = None) -> None:
+ self.message = message
+ self.page = page
+
+ def __str__(self) -> str:
+ if self.page is None:
+ return self.message
+ else:
+ return f"{self.message}\n\nFor further information visit '{DISTILABEL_DOCS_URL}{self.page}'"
+
+
+class DistilabelUserError(DistilabelError, ValueError):
+ """ValueError that we can redirect to a given page in the documentation."""
+
+ pass
+
+
+class DistilabelTypeError(DistilabelError, TypeError):
+ """TypeError that we can redirect to a given page in the documentation."""
+
+ pass
+
+
+class DistilabelNotImplementedError(DistilabelError, NotImplementedError):
+ """NotImplementedError that we can redirect to a given page in the documentation."""
+
+ pass
diff --git a/src/distilabel/exceptions.py b/src/distilabel/exceptions.py
new file mode 100644
index 0000000000..79b4f3cfb5
--- /dev/null
+++ b/src/distilabel/exceptions.py
@@ -0,0 +1,40 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Tuple
+
+
+class DistilabelException(Exception):
+ """Base exception (can be gracefully handled) for `distilabel` framework."""
+
+ pass
+
+
+class DistilabelGenerationException(DistilabelException):
+ """Base exception for `LLM` generation errors."""
+
+ pass
+
+
+class DistilabelOfflineBatchGenerationNotFinishedException(
+ DistilabelGenerationException
+):
+ """Exception raised when a batch generation is not finished."""
+
+ jobs_ids: Tuple[str, ...]
+
+ def __init__(self, jobs_ids: Tuple[str, ...]) -> None:
+ self.jobs_ids = jobs_ids
+ super().__init__(f"Batch generation with jobs_ids={jobs_ids} is not finished")
diff --git a/src/distilabel/llms/_dummy.py b/src/distilabel/llms/_dummy.py
new file mode 100644
index 0000000000..740f98cd46
--- /dev/null
+++ b/src/distilabel/llms/_dummy.py
@@ -0,0 +1,70 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, List
+
+from distilabel.llms.base import LLM, AsyncLLM
+from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+
+if TYPE_CHECKING:
+ from distilabel.llms.typing import GenerateOutput
+ from distilabel.steps.tasks.typing import FormattedInput
+
+
+class DummyAsyncLLM(AsyncLLM):
+ structured_output: Any = None
+
+ def load(self) -> None:
+ pass
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ async def agenerate( # type: ignore
+ self, input: "FormattedInput", num_generations: int = 1
+ ) -> "GenerateOutput":
+ return ["output" for _ in range(num_generations)]
+
+
+class DummySyncLLM(LLM):
+ structured_output: Any = None
+
+ def load(self) -> None:
+ super().load()
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ def generate( # type: ignore
+ self, inputs: "FormattedInput", num_generations: int = 1
+ ) -> "GenerateOutput":
+ return [["output" for _ in range(num_generations)] for _ in range(len(inputs))]
+
+
+class DummyMagpieLLM(LLM, MagpieChatTemplateMixin):
+ def load(self) -> None:
+ pass
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ def generate(
+ self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any
+ ) -> List["GenerateOutput"]:
+ return [
+ ["Hello Magpie" for _ in range(num_generations)] for _ in range(len(inputs))
+ ]
diff --git a/src/distilabel/llms/anthropic.py b/src/distilabel/llms/anthropic.py
index 843b14b21f..f938da58d2 100644
--- a/src/distilabel/llms/anthropic.py
+++ b/src/distilabel/llms/anthropic.py
@@ -75,7 +75,6 @@ class AnthropicLLM(AsyncLLM):
Defaults to `6`.
Examples:
-
Generate text:
```python
@@ -85,11 +84,7 @@ class AnthropicLLM(AsyncLLM):
llm.load()
- # Synchronous request
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
-
- # Asynchronous request
- output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
@@ -111,7 +106,7 @@ class User(BaseModel):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
diff --git a/src/distilabel/llms/anyscale.py b/src/distilabel/llms/anyscale.py
index 54b777b8a8..1d4114d383 100644
--- a/src/distilabel/llms/anyscale.py
+++ b/src/distilabel/llms/anyscale.py
@@ -40,7 +40,6 @@ class AnyscaleLLM(OpenAILLM):
It is meant to be used internally.
Examples:
-
Generate text:
```python
@@ -50,7 +49,7 @@ class AnyscaleLLM(OpenAILLM):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
"""
diff --git a/src/distilabel/llms/azure.py b/src/distilabel/llms/azure.py
index 80c0807572..58ed15010f 100644
--- a/src/distilabel/llms/azure.py
+++ b/src/distilabel/llms/azure.py
@@ -48,7 +48,6 @@ class AzureOpenAILLM(OpenAILLM):
`:material-microsoft-azure:`
Examples:
-
Generate text:
```python
@@ -58,11 +57,7 @@ class AzureOpenAILLM(OpenAILLM):
llm.load()
- # Synchrounous request
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
-
- # Asynchronous request
- output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate text from a custom endpoint following the OpenAI API:
@@ -77,11 +72,7 @@ class AzureOpenAILLM(OpenAILLM):
llm.load()
- # Synchronous request
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
-
- # Asynchronous request
- output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
@@ -103,7 +94,7 @@ class User(BaseModel):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py
index 68d82001d2..ced6a8e041 100644
--- a/src/distilabel/llms/base.py
+++ b/src/distilabel/llms/base.py
@@ -16,13 +16,18 @@
import inspect
import json
import logging
+import os
import sys
+import time
from abc import ABC, abstractmethod
from functools import cached_property
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
+from distilabel.constants import SIGINT_HANDLER_CALLED_ENV_NAME
+from distilabel.errors import DistilabelNotImplementedError, DistilabelUserError
+from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
from distilabel.mixins.runtime_parameters import (
RuntimeParameter,
RuntimeParametersMixin,
@@ -66,6 +71,15 @@ class LLM(RuntimeParametersMixin, BaseModel, _Serializable, ABC):
Attributes:
generation_kwargs: the kwargs to be propagated to either `generate` or `agenerate`
methods within each `LLM`.
+ use_offline_batch_generation: whether to use the `offline_batch_generate` method to
+ generate the responses.
+ offline_batch_generation_block_until_done: if provided, then polling will be done until
+ the `ofline_batch_generate` method is able to retrieve the results. The value indicate
+ the time to wait between each polling.
+ jobs_ids: the job ids generated by the `offline_batch_generate` method. This attribute
+ is used to store the job ids generated by the `offline_batch_generate` method
+ so later they can be used to retrieve the results. It is not meant to be set by
+ the user.
_logger: the logger to be used for the `LLM`. It will be initialized when the `load`
method is called.
"""
@@ -83,7 +97,19 @@ class LLM(RuntimeParametersMixin, BaseModel, _Serializable, ABC):
description="The kwargs to be propagated to either `generate` or `agenerate`"
" methods within each `LLM`.",
)
+ use_offline_batch_generation: Optional[RuntimeParameter[bool]] = Field(
+ default=False,
+ description="Whether to use the `offline_batch_generate` method to generate"
+ " the responses.",
+ )
+ offline_batch_generation_block_until_done: Optional[RuntimeParameter[int]] = Field(
+ default=None,
+ description="If provided, then polling will be done until the `ofline_batch_generate`"
+ " method is able to retrieve the results. The value indicate the time to wait between"
+ " each polling.",
+ )
+ jobs_ids: Union[Tuple[str, ...], None] = Field(default=None)
_logger: "Logger" = PrivateAttr(None)
def load(self) -> None:
@@ -137,6 +163,84 @@ def generate(
"""
pass
+ def generate_outputs(
+ self,
+ inputs: List["FormattedInput"],
+ num_generations: int = 1,
+ **kwargs: Any,
+ ) -> List["GenerateOutput"]:
+ """Generates outputs for the given inputs using either `generate` method or the
+ `offine_batch_generate` method if `use_offline_
+ """
+ if self.use_offline_batch_generation:
+ if self.offline_batch_generation_block_until_done is not None:
+ return self._offline_batch_generate_polling(
+ inputs=inputs,
+ num_generations=num_generations,
+ **kwargs,
+ )
+
+ # This will raise `DistilabelOfflineBatchGenerationNotFinishedException` right away
+ # if the batch generation is not finished.
+ return self.offline_batch_generate(
+ inputs=inputs,
+ num_generations=num_generations,
+ **kwargs,
+ )
+
+ return self.generate(inputs=inputs, num_generations=num_generations, **kwargs)
+
+ def _offline_batch_generate_polling(
+ self,
+ inputs: List["FormattedInput"],
+ num_generations: int = 1,
+ **kwargs: Any,
+ ) -> List["GenerateOutput"]:
+ """Method to poll the `offline_batch_generate` method until the batch generation
+ is finished.
+
+ Args:
+ inputs: the list of inputs to generate responses for.
+ num_generations: the number of generations to generate per input.
+ **kwargs: the additional kwargs to be used for the generation.
+
+ Returns:
+ A list containing the generations for each input.
+ """
+ while True:
+ try:
+ return self.offline_batch_generate(
+ inputs=inputs,
+ num_generations=num_generations,
+ **kwargs,
+ )
+ except DistilabelOfflineBatchGenerationNotFinishedException as e:
+ self._logger.info(
+ f"Waiting for the offline batch generation to finish: {e}. Sleeping"
+ f" for {self.offline_batch_generation_block_until_done} seconds before"
+ " trying to get the results again."
+ )
+ # When running a `Step` in a child process, SIGINT is overriden so the child
+ # process doesn't stop when the parent process receives a SIGINT signal.
+ # The new handler sets an environment variable that is checked here to stop
+ # the polling.
+ if os.getenv(SIGINT_HANDLER_CALLED_ENV_NAME) is not None:
+ self._logger.info(
+ "Received a KeyboardInterrupt. Stopping polling for checking if the"
+ " offline batch generation is finished..."
+ )
+ raise e
+ time.sleep(self.offline_batch_generation_block_until_done) # type: ignore
+ except KeyboardInterrupt as e:
+ # This is for the case the `LLM` is being executed outside a pipeline
+ self._logger.info(
+ "Received a KeyboardInterrupt. Stopping polling for checking if the"
+ " offline batch generation is finished..."
+ )
+ raise DistilabelOfflineBatchGenerationNotFinishedException(
+ jobs_ids=self.jobs_ids # type: ignore
+ ) from e
+
@property
def generate_parameters(self) -> List["inspect.Parameter"]:
"""Returns the parameters of the `generate` method.
@@ -224,6 +328,7 @@ def get_last_hidden_states(
A list containing the last hidden state for each sequence using a NumPy array
with shape [num_tokens, hidden_size].
"""
+ # TODO: update to use `DistilabelNotImplementedError`
raise NotImplementedError(
f"Method `get_last_hidden_states` is not implemented for `{self.__class__.__name__}`"
)
@@ -242,10 +347,40 @@ def _prepare_structured_output(
Returns:
The structure to be used for the guided generation.
"""
+ # TODO: update to use `DistilabelNotImplementedError`
raise NotImplementedError(
f"Guided generation is not implemented for `{type(self).__name__}`"
)
+ def offline_batch_generate(
+ self,
+ inputs: Union[List["FormattedInput"], None] = None,
+ num_generations: int = 1,
+ **kwargs: Any,
+ ) -> List["GenerateOutput"]:
+ """Method to generate a list of outputs for the given inputs using an offline batch
+ generation method to be implemented by each `LLM`.
+
+ This method should create jobs the first time is called and store the job ids, so
+ the second and subsequent calls can retrieve the results of the batch generation.
+ If subsequent calls are made before the batch generation is finished, then the method
+ should raise a `DistilabelOfflineBatchGenerationNotFinishedException`. This exception
+ will be handled automatically by the `Pipeline` which will store all the required
+ information for recovering the pipeline execution when the batch generation is finished.
+
+ Args:
+ inputs: the list of inputs to generate responses for.
+ num_generations: the number of generations to generate per input.
+ **kwargs: the additional kwargs to be used for the generation.
+
+ Returns:
+ A list containing the generations for each input.
+ """
+ raise DistilabelNotImplementedError(
+ f"`offline_batch_generate` is not implemented for `{self.__class__.__name__}`",
+ page="sections/how_to_guides/advanced/offline-batch-generation/",
+ )
+
class AsyncLLM(LLM):
"""Abstract class for asynchronous LLMs, so as to benefit from the async capabilities
@@ -400,8 +535,9 @@ def _prepare_structured_output( # type: ignore
schema = structured_output.get("schema")
if not schema:
- raise ValueError(
- f"The `structured_output` argument must contain a schema: {structured_output}"
+ raise DistilabelUserError(
+ f"The `structured_output` argument must contain a schema: {structured_output}",
+ page="sections/how_to_guides/advanced/structured_generation/#instructor",
)
if inspect.isclass(schema) and issubclass(schema, BaseModel):
# We want a json schema for the serialization, but instructor wants a pydantic BaseModel.
@@ -428,7 +564,10 @@ def _prepare_kwargs(
# We can deal with json schema or BaseModel, but we need to convert it to a BaseModel
# for the Instructor client.
schema = structured_output.get("schema", {})
- if not issubclass(schema, BaseModel):
+
+ # If there's already a pydantic model, we don't need to do anything,
+ # otherwise, try to obtain one.
+ if not (inspect.isclass(schema) and issubclass(schema, BaseModel)):
from distilabel.steps.tasks.structured_outputs.utils import (
json_schema_to_model,
)
diff --git a/src/distilabel/llms/cohere.py b/src/distilabel/llms/cohere.py
index a1295a9ba8..e9d0d0c0f2 100644
--- a/src/distilabel/llms/cohere.py
+++ b/src/distilabel/llms/cohere.py
@@ -70,7 +70,6 @@ class CohereLLM(AsyncLLM):
`"distilabel"`.
Examples:
-
Generate text:
```python
@@ -81,7 +80,7 @@ class CohereLLM(AsyncLLM):
llm.load()
# Call the model
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
Generate structured data:
@@ -102,7 +101,7 @@ class User(BaseModel):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
diff --git a/src/distilabel/llms/groq.py b/src/distilabel/llms/groq.py
index 3a362951ec..c4c2554329 100644
--- a/src/distilabel/llms/groq.py
+++ b/src/distilabel/llms/groq.py
@@ -63,7 +63,6 @@ class GroqLLM(AsyncLLM):
to `120`.
Examples:
-
Generate text:
```python
@@ -74,7 +73,7 @@ class GroqLLM(AsyncLLM):
llm.load()
# Call the model
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
Generate structured data:
@@ -95,7 +94,7 @@ class User(BaseModel):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py
index edfc508247..3566228f56 100644
--- a/src/distilabel/llms/huggingface/inference_endpoints.py
+++ b/src/distilabel/llms/huggingface/inference_endpoints.py
@@ -26,6 +26,7 @@
model_validator,
validate_call,
)
+from pydantic._internal._model_construction import ModelMetaclass
from typing_extensions import Annotated, override
from distilabel.llms.base import AsyncLLM
@@ -74,19 +75,18 @@ class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin):
`:hugging:`
Examples:
-
- Free serverless Inference API:
+ Free serverless Inference API, set the input_batch_size of the Task that uses this to avoid Model is overloaded:
```python
from distilabel.llms.huggingface import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
- model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
)
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Dedicated Inference Endpoints:
@@ -102,7 +102,7 @@ class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Dedicated Inference Endpoints or TGI:
@@ -117,7 +117,7 @@ class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
@@ -140,7 +140,7 @@ class User(BaseModel):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]])
```
"""
@@ -364,6 +364,12 @@ def _get_structured_output(
"the `structured_output` attribute."
) from e
+ if structured_output:
+ if isinstance(structured_output["value"], ModelMetaclass):
+ structured_output["value"] = structured_output[
+ "value"
+ ].model_json_schema()
+
return structured_output
async def _generate_with_text_generation(
diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py
index d3d16f6ac9..27ab00e5b9 100644
--- a/src/distilabel/llms/huggingface/transformers.py
+++ b/src/distilabel/llms/huggingface/transformers.py
@@ -76,7 +76,6 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
`:hugging:`
Examples:
-
Generate text:
```python
@@ -87,7 +86,7 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
llm.load()
# Call the model
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
"""
diff --git a/src/distilabel/llms/litellm.py b/src/distilabel/llms/litellm.py
index 71a73365bb..48361ef706 100644
--- a/src/distilabel/llms/litellm.py
+++ b/src/distilabel/llms/litellm.py
@@ -41,7 +41,6 @@ class LiteLLM(AsyncLLM):
- `verbose`: whether to log the LiteLLM client's logs. Defaults to `False`.
Examples:
-
Generate text:
```python
@@ -73,7 +72,7 @@ class User(BaseModel):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
diff --git a/src/distilabel/llms/llamacpp.py b/src/distilabel/llms/llamacpp.py
index f66eb214b0..9d158ea525 100644
--- a/src/distilabel/llms/llamacpp.py
+++ b/src/distilabel/llms/llamacpp.py
@@ -59,7 +59,6 @@ class LlamaCppLLM(LLM):
- [`llama-cpp-python`](https://github.com/abetlen/llama-cpp-python)
Examples:
-
Generate text:
```python
@@ -81,7 +80,7 @@ class LlamaCppLLM(LLM):
llm.load()
# Call the model
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
@@ -107,7 +106,7 @@ class User(BaseModel):
llm.load()
# Call the model
- output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
diff --git a/src/distilabel/llms/mistral.py b/src/distilabel/llms/mistral.py
index ed1c3af7d5..a913d6ad0a 100644
--- a/src/distilabel/llms/mistral.py
+++ b/src/distilabel/llms/mistral.py
@@ -26,7 +26,7 @@
)
if TYPE_CHECKING:
- from mistralai.async_client import MistralAsyncClient
+ from mistralai import Mistral
_MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY"
@@ -50,7 +50,7 @@ class MistralLLM(AsyncLLM):
`InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
_api_key_env_var: the name of the environment variable to use for the API key. It is meant to
be used internally.
- _aclient: the `MistralAsyncClient` to use for the Mistral API. It is meant to be used internally.
+ _aclient: the `Mistral` to use for the Mistral API. It is meant to be used internally.
Set in the `load` method.
Runtime parameters:
@@ -62,7 +62,6 @@ class MistralLLM(AsyncLLM):
Defaults to `64`.
Examples:
-
Generate text:
```python
@@ -94,7 +93,7 @@ class User(BaseModel):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
@@ -126,14 +125,14 @@ class User(BaseModel):
_num_generations_param_supported = False
_api_key_env_var: str = PrivateAttr(_MISTRALAI_API_KEY_ENV_VAR_NAME)
- _aclient: Optional["MistralAsyncClient"] = PrivateAttr(...)
+ _aclient: Optional["Mistral"] = PrivateAttr(...)
def load(self) -> None:
- """Loads the `MistralAsyncClient` client to benefit from async requests."""
+ """Loads the `Mistral` client to benefit from async requests."""
super().load()
try:
- from mistralai.async_client import MistralAsyncClient
+ from mistralai import Mistral
except ImportError as ie:
raise ImportError(
"MistralAI Python client is not installed. Please install it using"
@@ -146,7 +145,7 @@ def load(self) -> None:
f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
)
- self._aclient = MistralAsyncClient(
+ self._aclient = Mistral(
api_key=self.api_key.get_secret_value(),
endpoint=self.endpoint,
max_retries=self.max_retries, # type: ignore
@@ -218,7 +217,8 @@ async def agenerate( # type: ignore
# We need to check instructor and see if we can create a PR.
completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore
else:
- completion = await self._aclient.chat(**kwargs) # type: ignore
+ # completion = await self._aclient.chat(**kwargs) # type: ignore
+ completion = await self._aclient.chat.complete_async(**kwargs) # type: ignore
if structured_output:
generations.append(completion.model_dump_json())
diff --git a/src/distilabel/llms/mixins/cuda_device_placement.py b/src/distilabel/llms/mixins/cuda_device_placement.py
index 5642a10761..c7730940e4 100644
--- a/src/distilabel/llms/mixins/cuda_device_placement.py
+++ b/src/distilabel/llms/mixins/cuda_device_placement.py
@@ -207,11 +207,6 @@ def _get_cuda_device(self, device_map: Dict[str, List[int]]) -> Union[int, None]
return device
return None
- raise RuntimeError(
- "Couldn't find an available CUDA device automatically to be used by the LLM"
- f" '{self._llm_identifier}'. For forcing the use of a specific device, set the"
- " `cuda_devices` attribute to a list with the desired device(s)."
- )
def _set_cuda_visible_devices(self) -> None:
"""Sets the `CUDA_VISIBLE_DEVICES` environment variable to the list of CUDA devices
diff --git a/src/distilabel/llms/moa.py b/src/distilabel/llms/moa.py
index d139da87e3..a7dd5db19e 100644
--- a/src/distilabel/llms/moa.py
+++ b/src/distilabel/llms/moa.py
@@ -61,7 +61,6 @@ class MixtureOfAgentsLLM(AsyncLLM):
- [Mixture-of-Agents Enhances Large Language Model Capabilities](https://arxiv.org/abs/2406.04692)
Examples:
-
Generate text:
```python
@@ -91,7 +90,7 @@ class MixtureOfAgentsLLM(AsyncLLM):
llm.load()
- output = llm.generate(
+ output = llm.generate_outputs(
inputs=[
[
{
diff --git a/src/distilabel/llms/ollama.py b/src/distilabel/llms/ollama.py
index bd664b30db..fc3abd605b 100644
--- a/src/distilabel/llms/ollama.py
+++ b/src/distilabel/llms/ollama.py
@@ -79,6 +79,20 @@ class OllamaLLM(AsyncLLM):
Runtime parameters:
- `host`: the Ollama server host.
- `timeout`: the client timeout for the Ollama API. Defaults to `120`.
+
+ Examples:
+ Generate text:
+
+ ```python
+ from distilabel.llms import OllamaLLM
+
+ llm = OllamaLLM(model="llama3")
+
+ llm.load()
+
+ # Call the model
+ output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ ```
"""
model: str
diff --git a/src/distilabel/llms/openai.py b/src/distilabel/llms/openai.py
index 39644e2812..48cac8a50e 100644
--- a/src/distilabel/llms/openai.py
+++ b/src/distilabel/llms/openai.py
@@ -12,21 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import io
import os
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
+import orjson
from pydantic import Field, PrivateAttr, SecretStr, validate_call
+from distilabel import envs
+from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
from distilabel.llms.base import AsyncLLM
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
if TYPE_CHECKING:
- from openai import AsyncOpenAI
+ from openai import AsyncOpenAI, OpenAI
+ from openai.types import Batch as OpenAIBatch
+ from openai.types import FileObject as OpenAIFileObject
+ from openai.types.chat import ChatCompletion as OpenAIChatCompletion
+ from pydantic import BaseModel
_OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"
+_OPENAI_BATCH_API_MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
class OpenAILLM(AsyncLLM):
@@ -62,7 +71,6 @@ class OpenAILLM(AsyncLLM):
`:simple-openai:`
Examples:
-
Generate text:
```python
@@ -72,7 +80,7 @@ class OpenAILLM(AsyncLLM):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate text from a custom endpoint following the OpenAI API:
@@ -87,7 +95,7 @@ class OpenAILLM(AsyncLLM):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
@@ -109,7 +117,24 @@ class User(BaseModel):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+ ```
+
+ Generate with Batch API (offline batch generation):
+
+ ```python
+ from distilabel.llms import OpenAILLM
+
+ load = llm = OpenAILLM(
+ model="gpt-3.5-turbo",
+ use_offline_batch_generation=True,
+ offline_batch_generation_block_until_done=5, # poll for results every 5 seconds
+ )
+
+ llm.load()
+
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ # [['Hello! How can I assist you today?']]
```
"""
@@ -141,6 +166,7 @@ class User(BaseModel):
)
_api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME)
+ _client: "OpenAI" = PrivateAttr(None)
_aclient: "AsyncOpenAI" = PrivateAttr(None)
def load(self) -> None:
@@ -148,7 +174,7 @@ def load(self) -> None:
super().load()
try:
- from openai import AsyncOpenAI
+ from openai import AsyncOpenAI, OpenAI
except ImportError as ie:
raise ImportError(
"OpenAI Python client is not installed. Please install it using"
@@ -161,6 +187,13 @@ def load(self) -> None:
f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
)
+ self._client = OpenAI(
+ base_url=self.base_url,
+ api_key=self.api_key.get_secret_value(),
+ max_retries=self.max_retries, # type: ignore
+ timeout=self.timeout,
+ )
+
self._aclient = AsyncOpenAI(
base_url=self.base_url,
api_key=self.api_key.get_secret_value(),
@@ -178,6 +211,15 @@ def load(self) -> None:
if structured_output := result.get("structured_output"):
self.structured_output = structured_output
+ def unload(self) -> None:
+ """Set clients to `None` as they both contain `thread._RLock` which cannot be pickled
+ in case an exception is raised and has to be handled in the main process"""
+
+ self._client = None # type: ignore
+ self._aclient = None # type: ignore
+ self.structured_output = None
+ super().unload()
+
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
@@ -228,11 +270,11 @@ async def agenerate( # type: ignore
if isinstance(input, tuple):
input, structured_output = input
result = self._prepare_structured_output(
- structured_output=structured_output,
+ structured_output=structured_output, # type: ignore
client=self._aclient,
framework="openai",
)
- self._aclient = result.get("client")
+ self._aclient = result.get("client") # type: ignore
if structured_output is None and self.structured_output is not None:
structured_output = self.structured_output
@@ -262,15 +304,41 @@ async def agenerate( # type: ignore
kwargs["response_format"] = response_format
if structured_output:
- kwargs = self._prepare_kwargs(kwargs, structured_output)
+ kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore
- generations = []
completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore
if structured_output:
- generations.append(completion.model_dump_json())
- return generations
+ return self._generations_from_structured_output(completion)
+
+ return self._generations_from_openai_completion(completion)
+
+ def _generations_from_structured_output(
+ self, completion: "BaseModel"
+ ) -> "GenerateOutput":
+ """Get the generations from the structured output object.
+
+ Args:
+ completion: an instance of `pydantic.BaseModel` with the content of the structuted
+ output.
+ Returns:
+ A list with the content of the structured output.
+ """
+ return [completion.model_dump_json()]
+
+ def _generations_from_openai_completion(
+ self, completion: "OpenAIChatCompletion"
+ ) -> "GenerateOutput":
+ """Get the generations from the OpenAI Chat Completion object.
+
+ Args:
+ completion: the completion object to get the generations from.
+
+ Returns:
+ A list of strings containing the generated responses for the input.
+ """
+ generations = []
for choice in completion.choices:
if (content := choice.message.content) is None:
self._logger.warning( # type: ignore
@@ -279,3 +347,349 @@ async def agenerate( # type: ignore
)
generations.append(content)
return generations
+
+ def offline_batch_generate(
+ self,
+ inputs: Union[List["FormattedInput"], None] = None,
+ num_generations: int = 1,
+ max_new_tokens: int = 128,
+ frequency_penalty: float = 0.0,
+ presence_penalty: float = 0.0,
+ temperature: float = 1.0,
+ top_p: float = 1.0,
+ stop: Optional[Union[str, List[str]]] = None,
+ response_format: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List["GenerateOutput"]:
+ """Uses the OpenAI batch API to generate `num_generations` responses for the given
+ inputs.
+
+ Args:
+ inputs: a list of inputs in chat format to generate responses for.
+ num_generations: the number of generations to create per input. Defaults to
+ `1`.
+ max_new_tokens: the maximum number of new tokens that the model will generate.
+ Defaults to `128`.
+ frequency_penalty: the repetition penalty to use for the generation. Defaults
+ to `0.0`.
+ presence_penalty: the presence penalty to use for the generation. Defaults to
+ `0.0`.
+ temperature: the temperature to use for the generation. Defaults to `0.1`.
+ top_p: the top-p value to use for the generation. Defaults to `1.0`.
+ stop: a string or a list of strings to use as a stop sequence for the generation.
+ Defaults to `None`.
+ response_format: the format of the response to return. Must be one of
+ "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
+ for more information on how to use the JSON model from OpenAI. Defaults to `text`.
+
+ Returns:
+ A list of lists of strings containing the generated responses for each input
+ in `inputs`.
+
+ Raises:
+ DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
+ is not finished yet.
+ ValueError: if no job IDs were found to retrieve the results from.
+ """
+ if self.jobs_ids:
+ return self._check_and_get_batch_results()
+
+ if inputs:
+ self.jobs_ids = self._create_jobs(
+ inputs=inputs,
+ **{
+ "model": self.model,
+ "max_tokens": max_new_tokens,
+ "n": num_generations,
+ "frequency_penalty": frequency_penalty,
+ "presence_penalty": presence_penalty,
+ "temperature": temperature,
+ "top_p": top_p,
+ "stop": stop,
+ "response_format": response_format,
+ },
+ )
+ raise DistilabelOfflineBatchGenerationNotFinishedException(
+ jobs_ids=self.jobs_ids
+ )
+
+ raise ValueError("No `inputs` were provided and no `jobs_ids` were found.")
+
+ def _check_and_get_batch_results(self) -> List["GenerateOutput"]:
+ """Checks the status of the batch jobs and retrieves the results from the OpenAI
+ Batch API.
+
+ Returns:
+ A list of lists of strings containing the generated responses for each input.
+
+ Raises:
+ ValueError: if no job IDs were found to retrieve the results from.
+ DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
+ is not finished yet.
+ RuntimeError: if the only batch job found failed.
+ """
+ if not self.jobs_ids:
+ raise ValueError("No job IDs were found to retrieve the results from.")
+
+ outputs = []
+ for batch_id in self.jobs_ids:
+ batch = self._get_openai_batch(batch_id)
+
+ if batch.status in ("validating", "in_progress", "finalizing"):
+ raise DistilabelOfflineBatchGenerationNotFinishedException(
+ jobs_ids=self.jobs_ids
+ )
+
+ if batch.status in ("failed", "expired", "cancelled", "cancelling"):
+ self._logger.error( # type: ignore
+ f"OpenAI API batch with ID '{batch_id}' failed with status '{batch.status}'."
+ )
+ if len(self.jobs_ids) == 1:
+ self.jobs_ids = None
+ raise RuntimeError(
+ f"The only OpenAI API Batch that was created with ID '{batch_id}'"
+ f" failed with status '{batch.status}'."
+ )
+
+ continue
+
+ outputs.extend(self._retrieve_batch_results(batch))
+
+ # sort by `custom_id` to return the results in the same order as the inputs
+ outputs = sorted(outputs, key=lambda x: int(x["custom_id"]))
+ return [self._parse_output(output) for output in outputs]
+
+ def _parse_output(self, output: Dict[str, Any]) -> "GenerateOutput":
+ """Parses the output from the OpenAI Batch API into a list of strings.
+
+ Args:
+ output: the output to parse.
+
+ Returns:
+ A list of strings containing the generated responses for the input.
+ """
+ from openai.types.chat import ChatCompletion as OpenAIChatCompletion
+
+ if "response" not in output:
+ return []
+
+ if output["response"]["status_code"] != 200:
+ return []
+
+ return self._generations_from_openai_completion(
+ OpenAIChatCompletion(**output["response"]["body"])
+ )
+
+ def _get_openai_batch(self, batch_id: str) -> "OpenAIBatch":
+ """Gets a batch from the OpenAI Batch API.
+
+ Args:
+ batch_id: the ID of the batch to retrieve.
+
+ Returns:
+ The batch retrieved from the OpenAI Batch API.
+
+ Raises:
+ openai.OpenAIError: if there was an error while retrieving the batch from the
+ OpenAI Batch API.
+ """
+ import openai
+
+ try:
+ return self._client.batches.retrieve(batch_id)
+ except openai.OpenAIError as e:
+ self._logger.error( # type: ignore
+ f"Error while retrieving batch '{batch_id}' from OpenAI: {e}"
+ )
+ raise e
+
+ def _retrieve_batch_results(self, batch: "OpenAIBatch") -> List[Dict[str, Any]]:
+ """Retrieves the results of a batch from its output file, parsing the JSONL content
+ into a list of dictionaries.
+
+ Args:
+ batch: the batch to retrieve the results from.
+
+ Returns:
+ A list of dictionaries containing the results of the batch.
+
+ Raises:
+ AssertionError: if no output file ID was found in the batch.
+ """
+ import openai
+
+ assert batch.output_file_id, "No output file ID was found in the batch."
+
+ try:
+ file_response = self._client.files.content(batch.output_file_id)
+ return [orjson.loads(line) for line in file_response.text.splitlines()]
+ except openai.OpenAIError as e:
+ self._logger.error( # type: ignore
+ f"Error while retrieving batch results from file '{batch.output_file_id}': {e}"
+ )
+ return []
+
+ def _create_jobs(
+ self, inputs: List["FormattedInput"], **kwargs: Any
+ ) -> Tuple[str, ...]:
+ """Creates jobs in the OpenAI Batch API to generate responses for the given inputs.
+
+ Args:
+ inputs: a list of inputs in chat format to generate responses for.
+ kwargs: the keyword arguments to use for the generation.
+
+ Returns:
+ A list of job IDs created in the OpenAI Batch API.
+ """
+ batch_input_files = self._create_batch_files(inputs=inputs, **kwargs)
+ jobs = []
+ for batch_input_file in batch_input_files:
+ if batch := self._create_batch_api_job(batch_input_file):
+ jobs.append(batch.id)
+ return tuple(jobs)
+
+ def _create_batch_api_job(
+ self, batch_input_file: "OpenAIFileObject"
+ ) -> Union["OpenAIBatch", None]:
+ """Creates a job in the OpenAI Batch API to generate responses for the given input
+ file.
+
+ Args:
+ batch_input_file: the input file to generate responses for.
+
+ Returns:
+ The batch job created in the OpenAI Batch API.
+ """
+ import openai
+
+ metadata = {"description": "distilabel"}
+
+ if distilabel_pipeline_name := envs.DISTILABEL_PIPELINE_NAME:
+ metadata["distilabel_pipeline_name"] = distilabel_pipeline_name
+
+ if distilabel_pipeline_cache_id := envs.DISTILABEL_PIPELINE_CACHE_ID:
+ metadata["distilabel_pipeline_cache_id"] = distilabel_pipeline_cache_id
+
+ batch = None
+ try:
+ batch = self._client.batches.create(
+ completion_window="24h",
+ endpoint="/v1/chat/completions",
+ input_file_id=batch_input_file.id,
+ metadata=metadata,
+ )
+ except openai.OpenAIError as e:
+ self._logger.error( # type: ignore
+ f"Error while creating OpenAI Batch API job for file with ID"
+ f" '{batch_input_file.id}': {e}."
+ )
+ raise e
+ return batch
+
+ def _create_batch_files(
+ self, inputs: List["FormattedInput"], **kwargs: Any
+ ) -> List["OpenAIFileObject"]:
+ """Creates the necessary input files for the batch API to generate responses. The
+ maximum size of each file so the OpenAI Batch API can process it is 100MB, so we
+ need to split the inputs into multiple files if necessary.
+
+ More information: https://platform.openai.com/docs/api-reference/files/create
+
+ Args:
+ inputs: a list of inputs in chat format to generate responses for, optionally
+ including structured output.
+ kwargs: the keyword arguments to use for the generation.
+
+ Returns:
+ The list of file objects created for the OpenAI Batch API.
+
+ Raises:
+ openai.OpenAIError: if there was an error while creating the batch input file
+ in the OpenAI Batch API.
+ """
+ import openai
+
+ files = []
+ for file_no, buffer in enumerate(
+ self._create_jsonl_buffers(inputs=inputs, **kwargs)
+ ):
+ try:
+ # TODO: add distilabel pipeline name and id
+ batch_input_file = self._client.files.create(
+ file=(self._name_for_openai_files(file_no), buffer),
+ purpose="batch",
+ )
+ files.append(batch_input_file)
+ except openai.OpenAIError as e:
+ self._logger.error( # type: ignore
+ f"Error while creating OpenAI batch input file: {e}"
+ )
+ raise e
+ return files
+
+ def _create_jsonl_buffers(
+ self, inputs: List["FormattedInput"], **kwargs: Any
+ ) -> Generator[io.BytesIO, None, None]:
+ """Creates a generator of buffers containing the JSONL formatted inputs to be
+ used by the OpenAI Batch API. The buffers created are of size 100MB or less.
+
+ Args:
+ inputs: a list of inputs in chat format to generate responses for, optionally
+ including structured output.
+ kwargs: the keyword arguments to use for the generation.
+
+ Yields:
+ A buffer containing the JSONL formatted inputs to be used by the OpenAI Batch
+ API.
+ """
+ buffer = io.BytesIO()
+ buffer_current_size = 0
+ for i, input in enumerate(inputs):
+ # We create the smallest `custom_id` so we don't increase the size of the file
+ # to much, but we can still sort the results with the order of the inputs.
+ row = self._create_jsonl_row(input=input, custom_id=str(i), **kwargs)
+ row_size = len(row)
+ if row_size + buffer_current_size > _OPENAI_BATCH_API_MAX_FILE_SIZE:
+ buffer.seek(0)
+ yield buffer
+ buffer = io.BytesIO()
+ buffer_current_size = 0
+ buffer.write(row)
+ buffer_current_size += row_size
+
+ if buffer_current_size > 0:
+ buffer.seek(0)
+ yield buffer
+
+ def _create_jsonl_row(
+ self, input: "FormattedInput", custom_id: str, **kwargs: Any
+ ) -> bytes:
+ """Creates a JSONL formatted row to be used by the OpenAI Batch API.
+
+ Args:
+ input: a list of inputs in chat format to generate responses for, optionally
+ including structured output.
+ custom_id: a custom ID to use for the row.
+ kwargs: the keyword arguments to use for the generation.
+
+ Returns:
+ A JSONL formatted row to be used by the OpenAI Batch API.
+ """
+ # TODO: depending on the format of the input, add `response_format` to the kwargs
+ row = {
+ "custom_id": custom_id,
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {"messages": input, **kwargs},
+ }
+ json_row = orjson.dumps(row)
+ return json_row + b"\n"
+
+ def _name_for_openai_files(self, file_no: int) -> str:
+ if (
+ envs.DISTILABEL_PIPELINE_NAME is None
+ or envs.DISTILABEL_PIPELINE_CACHE_ID is None
+ ):
+ return f"distilabel-pipeline-fileno-{file_no}.jsonl"
+
+ return f"distilabel-pipeline-{envs.DISTILABEL_PIPELINE_NAME}-{envs.DISTILABEL_PIPELINE_CACHE_ID}-fileno-{file_no}.jsonl"
diff --git a/src/distilabel/llms/together.py b/src/distilabel/llms/together.py
index aa63ae1ad5..88e7fd7647 100644
--- a/src/distilabel/llms/together.py
+++ b/src/distilabel/llms/together.py
@@ -39,7 +39,6 @@ class TogetherLLM(OpenAILLM):
is meant to be used internally.
Examples:
-
Generate text:
```python
@@ -49,7 +48,7 @@ class TogetherLLM(OpenAILLM):
llm.load()
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
"""
diff --git a/src/distilabel/llms/vertexai.py b/src/distilabel/llms/vertexai.py
index f89a7b0912..0c49fa3931 100644
--- a/src/distilabel/llms/vertexai.py
+++ b/src/distilabel/llms/vertexai.py
@@ -43,6 +43,20 @@ class VertexAILLM(AsyncLLM):
Icon:
`:simple-googlecloud:`
+
+ Examples:
+ Generate text:
+
+ ```python
+ from distilabel.llms import VertexAILLM
+
+ llm = VertexAILLM(model="gemini-1.5-pro")
+
+ llm.load()
+
+ # Call the model
+ output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ ```
"""
model: str
diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py
index fc023fd83a..19212755d4 100644
--- a/src/distilabel/llms/vllm.py
+++ b/src/distilabel/llms/vllm.py
@@ -13,6 +13,7 @@
# limitations under the License.
import json
+from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
@@ -43,6 +44,13 @@
from distilabel.steps.tasks.typing import StandardInput
+LogitsProcessorFn = Union[
+ Callable[[List[int], Any], Any],
+ Callable[[List[int], List[int], Any], Any],
+]
+
+LogitsProcessors = List[LogitsProcessorFn]
+
class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
"""`vLLM` library LLM implementation.
@@ -91,7 +99,6 @@ class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
the `LLM` class of `vllm` library.
Examples:
-
Generate text:
```python
@@ -106,7 +113,7 @@ class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
llm.load()
# Call the model
- output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
@@ -128,7 +135,7 @@ class User(BaseModel):
llm.load()
# Call the model
- output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+ output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
@@ -159,7 +166,7 @@ class User(BaseModel):
_model: "_vLLM" = PrivateAttr(None)
_tokenizer: "PreTrainedTokenizer" = PrivateAttr(None)
- _logits_processor: Optional[Callable] = PrivateAttr(default=None)
+ _structured_output_logits_processor: Optional[Callable] = PrivateAttr(default=None)
def load(self) -> None:
"""Loads the `vLLM` model using either the path or the Hugging Face Hub repository id.
@@ -197,12 +204,14 @@ def load(self) -> None:
self._tokenizer.chat_template = self.chat_template # type: ignore
if self.structured_output:
- self._logits_processor = self._prepare_structured_output(
+ self._structured_output_logits_processor = self._prepare_structured_output(
self.structured_output
)
def unload(self) -> None:
"""Unloads the `vLLM` model."""
+ self._model = None # type: ignore
+ self._tokenizer = None # type: ignore
CudaDevicePlacementMixin.unload(self)
super().unload()
@@ -283,11 +292,17 @@ def generate( # type: ignore
inputs: List[FormattedInput],
num_generations: int = 1,
max_new_tokens: int = 128,
- frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
+ frequency_penalty: float = 0.0,
+ repetition_penalty: float = 1.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
+ min_p: float = 0.0,
+ stop: Optional[List[str]] = None,
+ stop_token_ids: Optional[List[int]] = None,
+ include_stop_str_in_output: bool = False,
+ logits_processors: Optional[LogitsProcessors] = None,
extra_sampling_params: Optional[Dict[str, Any]] = None,
) -> List[GenerateOutput]:
"""Generates `num_generations` responses for each input.
@@ -298,13 +313,24 @@ def generate( # type: ignore
`1`.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
- frequency_penalty: the repetition penalty to use for the generation. Defaults
- to `0.0`.
presence_penalty: the presence penalty to use for the generation. Defaults to
`0.0`.
+ frequency_penalty: the repetition penalty to use for the generation. Defaults
+ to `0.0`.
+ repetition_penalty: the repetition penalty to use for the generation Defaults to
+ `1.0`.
temperature: the temperature to use for the generation. Defaults to `0.1`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
top_k: the top-k value to use for the generation. Defaults to `0`.
+ min_p: the minimum probability to use for the generation. Defaults to `0.0`.
+ stop: a list of strings that will be used to stop the generation when found.
+ Defaults to `None`.
+ stop_token_ids: a list of token ids that will be used to stop the generation
+ when found. Defaults to `None`.
+ include_stop_str_in_output: whether to include the stop string in the output.
+ Defaults to `False`.
+ logits_processors: a list of functions to process the logits before sampling.
+ Defaults to `None`.
extra_sampling_params: dictionary with additional arguments to be passed to
the `SamplingParams` class from `vllm`.
@@ -313,8 +339,12 @@ def generate( # type: ignore
"""
from vllm import SamplingParams
+ if not logits_processors:
+ logits_processors = []
+
if extra_sampling_params is None:
extra_sampling_params = {}
+
structured_output = None
if isinstance(inputs[0], tuple):
@@ -324,25 +354,31 @@ def generate( # type: ignore
prepared_batches = [([self.prepare_input(input) for input in inputs], None)]
sorted_indices = None
- # In case we have a single structured output for the dataset, we can
- logits_processors = None
- if self._logits_processor:
- logits_processors = [self._logits_processor]
+ # Case in which we have a single structured output for the dataset
+ if self._structured_output_logits_processor:
+ logits_processors.append(self._structured_output_logits_processor)
batched_outputs = []
for prepared_inputs, structured_output in prepared_batches:
if structured_output:
- logits_processors = [self._prepare_structured_output(structured_output)]
+ logits_processors.append(
+ self._prepare_structured_output(structured_output)
+ )
sampling_params = SamplingParams( # type: ignore
n=num_generations,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
+ repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
+ min_p=min_p,
max_tokens=max_new_tokens,
+ stop=stop,
+ stop_token_ids=stop_token_ids,
+ include_stop_str_in_output=include_stop_str_in_output,
logits_processors=logits_processors,
**extra_sampling_params,
)
@@ -414,7 +450,6 @@ class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin):
created to comunicate with the `vLLM` server. Defaults to `None`.
Examples:
-
Generate text:
```python
@@ -427,7 +462,7 @@ class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin):
llm.load()
- results = llm.generate(
+ results = llm.generate_outputs(
inputs=[[{"role": "user", "content": "Hello, how are you?"}]],
temperature=0.7,
top_p=1.0,
@@ -488,8 +523,8 @@ def load(self) -> None:
self.tokenizer, revision=self.tokenizer_revision
)
- @property
- def model_name(self) -> str:
+ @cached_property
+ def model_name(self) -> str: # type: ignore
"""Returns the name of the model served with vLLM server."""
models = self._client.models.list()
return models.data[0].id
@@ -530,7 +565,7 @@ async def agenerate( # type: ignore
"""Generates `num_generations` responses for each input.
Args:
- inputs: a list of inputs in chat format to generate responses for.
+ input: a single input in chat format to generate responses for.
num_generations: the number of generations to create per input. Defaults to
`1`.
max_new_tokens: the maximum number of new tokens that the model will generate.
diff --git a/src/distilabel/mixins/runtime_parameters.py b/src/distilabel/mixins/runtime_parameters.py
index a7dd848f17..f8371e30ab 100644
--- a/src/distilabel/mixins/runtime_parameters.py
+++ b/src/distilabel/mixins/runtime_parameters.py
@@ -13,13 +13,16 @@
# limitations under the License.
import difflib
-import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, TypeVar, Union
from pydantic import BaseModel, Field, PrivateAttr
-from pydantic.types import _SecretField
from typing_extensions import Annotated, get_args, get_origin
+from distilabel.utils.typing_ import (
+ extract_annotation_inner_type,
+ is_type_pydantic_secret_field,
+)
+
if TYPE_CHECKING:
from pydantic.fields import FieldInfo
@@ -73,8 +76,12 @@ def runtime_parameters_names(self) -> "RuntimeParametersNames":
if isinstance(attr, RuntimeParametersMixin):
runtime_parameters[name] = attr.runtime_parameters_names
- # `field: List[RuntiemParametersMixin]`
- if isinstance(attr, list) and isinstance(attr[0], RuntimeParametersMixin):
+ # `field: List[RuntimeParametersMixin]`
+ if (
+ isinstance(attr, list)
+ and attr
+ and isinstance(attr[0], RuntimeParametersMixin)
+ ):
runtime_parameters[name] = {
str(i): item.runtime_parameters_names for i, item in enumerate(attr)
}
@@ -170,8 +177,8 @@ def set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None:
# Handle settings values for `_SecretField`
field_info = self.model_fields[name]
- inner_type = _extract_runtime_parameter_inner_type(field_info.annotation)
- if inspect.isclass(inner_type) and issubclass(inner_type, _SecretField):
+ inner_type = extract_annotation_inner_type(field_info.annotation)
+ if is_type_pydantic_secret_field(inner_type):
value = inner_type(value)
# Set the value of the runtime parameter
@@ -211,22 +218,3 @@ def _is_runtime_parameter(field: "FieldInfo") -> Tuple[bool, bool]:
return True, is_optional
return False, False
-
-
-def _extract_runtime_parameter_inner_type(type_hint: Any) -> Any:
- """Extracts the inner type of a `RuntimeParameter` type hint.
-
- Args:
- type_hint: The type hint to extract the inner type from.
-
- Returns:
- The inner type of the `RuntimeParameter` type hint.
- """
- type_hint_args = get_args(type_hint)
- if get_origin(type_hint) is Annotated:
- return _extract_runtime_parameter_inner_type(type_hint_args[0])
-
- if get_origin(type_hint) is Union and type(None) in type_hint_args:
- return _extract_runtime_parameter_inner_type(type_hint_args[0])
-
- return type_hint
diff --git a/src/distilabel/mixins/signature.py b/src/distilabel/mixins/signature.py
new file mode 100644
index 0000000000..b014f03e90
--- /dev/null
+++ b/src/distilabel/mixins/signature.py
@@ -0,0 +1,83 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+from typing import TYPE_CHECKING, Any, List, Set
+
+from pydantic import BaseModel, Field
+
+from distilabel.utils.serialization import TYPE_INFO_KEY
+
+if TYPE_CHECKING:
+ pass
+
+# Add here the name of the attributes that shouldn't be used to generate the signature.
+# Attributes from a `BaseModel` that is an attribute from the root class must be prefixed
+# with the name of the attribute followed by an underscore. For example, if the attribute
+# `jobs_ids` is an attribute from the `llm` attribute of the root class it should be added
+# as `llm_jobs_ids`.
+_EXCLUDE_FROM_SIGNATURE_DEFAULTS = {
+ TYPE_INFO_KEY,
+ "disable_cuda_device_placement",
+ "input_batch_size",
+ "gpu_memory_utilization",
+ "resources",
+ "exclude_from_signature",
+ "llm_jobs_ids",
+ "llm_offline_batch_generation_block_until_done",
+}
+
+
+class SignatureMixin(BaseModel):
+ """Mixin for creating a signature (for cache) of the class.
+
+ Attributes:
+ exclude_from_signature: list of attributes to exclude from the signature.
+ """
+
+ exclude_from_signature: Set[str] = Field(
+ default=_EXCLUDE_FROM_SIGNATURE_DEFAULTS, exclude=True
+ )
+
+ @property
+ def signature(self) -> str:
+ """Makes a signature (hash) of the class, using its attributes.
+
+ Returns:
+ signature of the class.
+ """
+
+ def flatten_dump(d: Any, parent_key: str = "", sep: str = "_") -> List:
+ items = []
+ for k, v in d.items():
+ new_key = parent_key + sep + k if parent_key else k
+ if isinstance(v, dict):
+ items.extend(flatten_dump(v, new_key, sep=sep))
+ elif isinstance(v, list):
+ if len(v) == 0:
+ items.append((new_key, ""))
+ elif isinstance(v[0], str):
+ items.append((new_key, "-".join(v)))
+ else:
+ for i, x in enumerate(v):
+ items.extend(flatten_dump(x, f"{new_key}-{i}", sep=sep))
+ elif new_key not in self.exclude_from_signature:
+ items.append((new_key, v))
+ return items
+
+ info = []
+ for name, value in flatten_dump(self.dump()):
+ info.append(f"{name}-{str(value)}")
+
+ return hashlib.sha1("-".join(info).encode()).hexdigest()
diff --git a/src/distilabel/pipeline/_dag.py b/src/distilabel/pipeline/_dag.py
index 3253ef864d..5962ecc4f0 100644
--- a/src/distilabel/pipeline/_dag.py
+++ b/src/distilabel/pipeline/_dag.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import base64
import inspect
from collections import defaultdict
from functools import cached_property
@@ -29,12 +29,15 @@
)
import networkx as nx
+import requests
from distilabel.constants import (
CONVERGENCE_STEP_ATTR_NAME,
+ RECEIVES_ROUTED_BATCHES_ATTR_NAME,
ROUTING_BATCH_FUNCTION_ATTR_NAME,
STEP_ATTR_NAME,
)
+from distilabel.errors import DistilabelUserError
from distilabel.pipeline.routing_batch_function import RoutingBatchFunction
from distilabel.steps.base import GeneratorStep
from distilabel.utils.serialization import (
@@ -47,6 +50,8 @@
from distilabel.mixins.runtime_parameters import RuntimeParametersNames
from distilabel.steps.base import GeneratorStep, Step, _Step
+_MERMAID_URL = "https://mermaid.ink/img/"
+
class DAG(_Serializable):
"""A Directed Acyclic Graph (DAG) to represent the pipeline.
@@ -153,8 +158,9 @@ def add_root_step(self, step: "GeneratorStep") -> None:
Args:
step: The generator step that will be set as the new root.
"""
- self.add_step(step)
- self.add_edge(step.name, next(iter(self)))
+ for other_step, level in self.trophic_levels.items():
+ if level == 1 and other_step != step.name:
+ self.add_edge(step.name, other_step) # type: ignore
@cached_property
def root_steps(self) -> Set[str]:
@@ -174,14 +180,14 @@ def leaf_steps(self) -> Set[str]:
"""
return {node for node, degree in self.G.out_degree() if degree == 0}
- @cached_property
+ @property
def trophic_levels(self) -> Dict[str, int]:
"""The trophic level of each step in the DAG.
Returns:
A dictionary with the trophic level of each step.
"""
- return {step: int(level) for step, level in nx.trophic_levels(self.G).items()}
+ return nx.trophic_levels(self.G)
def get_step_predecessors(self, step_name: str) -> Iterable[str]:
"""Gets the predecessors of a step.
@@ -248,6 +254,21 @@ def is_step_in_trophic_level(self, step_name: str, trophic_level: int) -> bool:
"""
return self.get_step_trophic_level(step_name) == trophic_level
+ def is_convergence_step(self, step_name: str) -> bool:
+ """Checks if a given step is a convegence step.
+
+ Args:
+ step_name: Name of the step to check if a convergence step.
+
+ Returns:
+ True if it is, False otherwise.
+ """
+ predecessors = list(self.get_step_predecessors(step_name))
+ return all(
+ self.get_step(predecessor).get(RECEIVES_ROUTED_BATCHES_ATTR_NAME, False)
+ for predecessor in predecessors
+ )
+
def step_in_last_trophic_level(self, step_name: str) -> bool:
"""Checks if a step is in the last trophic level.
@@ -292,13 +313,19 @@ def _get_stage_last_steps(stage_steps: List[str]) -> List[str]:
current_stage = []
stages_last_steps = []
- for step_name in nx.topological_sort(self.G):
+ steps_sorted = list(nx.topological_sort(self.G))
+ for i, step_name in enumerate(steps_sorted):
step: "_Step" = self.get_step(step_name)[STEP_ATTR_NAME]
if not step.is_global:
current_stage.append(step_name)
else:
- stages.append(current_stage)
- stages_last_steps.append(_get_stage_last_steps(current_stage))
+ previous_step = None
+ if i > 0:
+ previous_step_name = steps_sorted[i - 1]
+ previous_step = self.get_step(previous_step_name)[STEP_ATTR_NAME]
+ if not previous_step or not previous_step.is_global:
+ stages.append(current_stage)
+ stages_last_steps.append(_get_stage_last_steps(current_stage))
stages.append([step_name])
stages_last_steps.append([step_name])
current_stage = []
@@ -339,9 +366,10 @@ def validate(self) -> None:
# Validate that the steps in the first trophic level are `GeneratorStep`s
if trophic_level == 1:
if not isinstance(step, GeneratorStep):
- raise ValueError(
+ raise DistilabelUserError(
f"Step '{step_name}' cannot be a root step because it is not"
- " a `GeneratorStep`. It should have a previous step in the pipeline."
+ " a `GeneratorStep`. It should have a previous step in the pipeline.",
+ page="sections/how_to_guides/basic/step/#types-of-steps",
)
self._validate_generator_step_process_signature(step)
else:
@@ -374,9 +402,10 @@ def _step_inputs_are_available(self, step: "_Step") -> None:
for output in self.get_step(step_name)[STEP_ATTR_NAME].get_outputs() # type: ignore
]
step_inputs = step.get_inputs()
- if not all(input in inputs_available_for_step for input in step_inputs):
+ required_inputs = [input for input, required in step_inputs.items() if required]
+ if not all(input in inputs_available_for_step for input in required_inputs):
raise ValueError(
- f"Step '{step.name}' requires inputs {step_inputs}, but only the inputs"
+ f"Step '{step.name}' requires inputs {required_inputs}, but only the inputs"
f"={inputs_available_for_step} are available, which means that the inputs"
f"={list(set(step_inputs) - set(inputs_available_for_step))} are missing or not"
" available when the step gets to be executed in the pipeline."
@@ -442,7 +471,7 @@ def _validate_convergence_step(
# Check if the `input_batch_size` of the step is equal or lower than the
for predecessor in predecessors:
- prev_step: "Step" = self.get_step(predecessor)[STEP_ATTR_NAME]
+ prev_step: "Step" = self.get_step(predecessor)[STEP_ATTR_NAME] # type: ignore
if step.input_batch_size > prev_step.input_batch_size: # type: ignore
raise ValueError(
"A convergence step should have an `input_batch_size` equal or lower"
@@ -478,9 +507,10 @@ def _validate_routing_batch_function(
node = self.get_step(predecessor)
routing_batch_function = node.get(ROUTING_BATCH_FUNCTION_ATTR_NAME)
if routing_batch_function is not None and len(predecessors) > 1:
- raise ValueError(
+ raise DistilabelUserError(
f"Step '{step.name}' cannot have multiple predecessors when the batches"
- " of one are being routed with a `routing_batch_function`."
+ " of one are being routed with a `routing_batch_function`.",
+ page="sections/how_to_guides/basic/pipeline/?h=routing#routing-batches-to-specific-downstream-steps",
)
if routing_batch_function is None:
@@ -541,24 +571,27 @@ def _validate_process_step_input_parameter(
if step_input_parameter is None:
if num_predecessors > 1:
prev_steps = ", ".join([f"'{step_name}'" for step_name in predecessors])
- raise ValueError(
+ raise DistilabelUserError(
f"Step '{step_name}' should have a `*args` parameter with type hint"
- f" `StepInput` to receive outputs from previous steps: {prev_steps}."
+ f" `StepInput` to receive outputs from previous steps: {prev_steps}.",
+ page="sections/how_to_guides/basic/step/#define-steps-for-your-pipeline",
)
prev_step_name = next(iter(predecessors))
- raise ValueError(
+ raise DistilabelUserError(
f"Step '{step_name}' should have a parameter with type hint `StepInput`"
- f" to receive the output from the previous step: '{prev_step_name}'."
+ f" to receive the output from the previous step: '{prev_step_name}'.",
+ page="sections/how_to_guides/basic/step/#define-steps-for-your-pipeline",
)
if (
num_predecessors > 1
and step_input_parameter.kind != inspect.Parameter.VAR_POSITIONAL
):
- raise ValueError(
+ raise DistilabelUserError(
f"Step '{step_name}' should have a `*args` parameter with type hint `StepInput`"
- f" to receive outputs from previous steps."
+ f" to receive outputs from previous steps.",
+ page="sections/how_to_guides/basic/step/#define-steps-for-your-pipeline",
)
def _validate_step_process_runtime_parameters( # noqa: C901
@@ -735,3 +768,186 @@ def from_dict(cls, data: Dict[str, Any]) -> "DAG":
)
return dag
+
+ def _get_graph_info_for_draw(
+ self,
+ ) -> Tuple[
+ Set[str],
+ Dict[str, str],
+ List[Dict[str, Any]],
+ Dict[str, Dict[str, Any]],
+ Dict[str, Dict[str, Any]],
+ Dict[str, Dict[str, Any]],
+ ]:
+ """Returns the graph info.
+
+ Returns:
+ all_steps: The set of all steps in the graph.
+ step_name_to_class: The mapping of step names to their classes.
+ connections: The list of connections in the graph.
+ step_outputs: The mapping of step names to their outputs.
+ step_output_mappings: The mapping of step names to their output mappings.
+ step_input_mappings: The mapping of step names to their input mappings.
+ """
+ dump = self.dump()
+ step_name_to_class = {
+ step["step"].get("name"): step["step"].get("type_info", {}).get("name")
+ for step in dump["steps"]
+ }
+ connections = dump["connections"]
+
+ step_outputs = {}
+ for step in dump["steps"]:
+ try:
+ step_outputs[step["name"]] = self.get_step(step["name"])[
+ STEP_ATTR_NAME
+ ].get_outputs()
+ except AttributeError:
+ step_outputs[step["name"]] = {"dynamic": True}
+ step_inputs = {}
+ for step in dump["steps"]:
+ try:
+ step_inputs[step["name"]] = self.get_step(step["name"])[
+ STEP_ATTR_NAME
+ ].get_inputs()
+ except AttributeError:
+ step_inputs[step["name"]] = {"dynamic": True}
+
+ # Add Argilla and Distiset steps to the graph
+ leaf_steps = self.leaf_steps
+ for idx, leaf_step in enumerate(leaf_steps):
+ if "to_argilla" in leaf_step:
+ connections.append({"from": leaf_step, "to": [f"to_argilla_{idx}"]})
+ step_name_to_class[f"to_argilla_{idx}"] = "Argilla"
+ step_outputs[leaf_step] = {"records": True}
+ else:
+ connections.append({"from": leaf_step, "to": [f"distiset_{idx}"]})
+ step_name_to_class[f"distiset_{idx}"] = "Distiset"
+
+ # Create a set of all steps in the graph
+ all_steps = {con["from"] for con in connections} | {
+ to_step for con in connections for to_step in con["to"]
+ }
+
+ # Create a mapping of step outputs
+ step_output_mappings = {
+ step["name"]: {
+ k: v
+ for k, v in {
+ **{output: output for output in step_outputs[step["name"]]},
+ **step["step"]["output_mappings"],
+ }.items()
+ if list(
+ dict(
+ {
+ **{output: output for output in step_outputs[step["name"]]},
+ **step["step"]["output_mappings"],
+ }.items()
+ ).values()
+ ).count(v)
+ == 1
+ or k != v
+ }
+ for step in dump["steps"]
+ }
+ step_input_mappings = {
+ step["name"]: dict(
+ {
+ **{input: input for input in step_inputs[step["name"]]},
+ **step["step"]["input_mappings"],
+ }.items()
+ )
+ for step in dump["steps"]
+ }
+
+ return (
+ all_steps,
+ step_name_to_class,
+ connections,
+ step_outputs,
+ step_output_mappings,
+ step_input_mappings,
+ )
+
+ def draw(self, top_to_bottom: bool = False, show_edge_labels: bool = True) -> str: # noqa: C901
+ """Draws the DAG and returns the image content.
+
+ Parameters:
+ top_to_bottom: Whether to draw the DAG top to bottom. Defaults to `False`.
+ show_edge_labels: Whether to show the edge labels. Defaults to `True`.
+
+ Returns:
+ The image content.
+ """
+ (
+ all_steps,
+ step_name_to_class,
+ connections,
+ step_outputs,
+ step_output_mappings,
+ step_input_mappings,
+ ) = self._get_graph_info_for_draw()
+ graph = [f"flowchart {'TD' if top_to_bottom else 'LR'}"]
+ for step in all_steps:
+ graph.append(f' {step}["{step_name_to_class[step]}"]')
+
+ if show_edge_labels:
+ for connection in connections:
+ from_step = connection["from"]
+ from_mapping = step_output_mappings[from_step]
+ for to_step in connection["to"]:
+ for from_column in set(
+ list(step_outputs[from_step].keys())
+ + list(step_output_mappings[from_step].keys())
+ ):
+ if from_column not in from_mapping:
+ continue
+ to_column = from_mapping.get(from_column)
+
+ # walk through mappings
+ to_mapping = step_input_mappings.get(to_step, {})
+ edge_label = [from_column]
+ if from_column != to_column:
+ edge_label.append(to_column)
+ if edge_label[-1] in to_mapping:
+ edge_label.append(to_mapping[edge_label[-1]])
+
+ if (
+ edge_label[-1] not in to_mapping
+ and from_step not in self.leaf_steps
+ ):
+ edge_label.append("**_pass_**")
+ edge_label = ":".join(list(dict.fromkeys(edge_label)))
+ graph.append(f" {from_step} --> |{edge_label}| {to_step}")
+
+ else:
+ for connection in connections:
+ from_step = connection["from"]
+ for to_step in connection["to"]:
+ graph.append(f" {from_step} --> {to_step}")
+
+ graph.append("classDef component text-align:center;")
+ graph_styled = "\n".join(graph)
+ return _to_mermaid_image(graph_styled)
+
+
+def _to_mermaid_image(graph_styled: str) -> str:
+ """Converts a Mermaid graph to an image using the Mermaid Ink service.
+
+ Parameters:
+ graph_styled: The Mermaid graph to convert to an image.
+
+ Returns:
+ The image content.
+ """
+ base64_string = base64.b64encode(graph_styled.encode("ascii")).decode("ascii")
+ url = f"{_MERMAID_URL}{base64_string}?type=png"
+
+ try:
+ response = requests.get(url, timeout=10)
+ response.raise_for_status()
+ return response.content
+ except requests.RequestException as e:
+ raise ValueError(
+ "Error accessing https://mermaid.ink/. See stacktrace for details."
+ ) from e
diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py
index 2cf65d2b38..c3397c392c 100644
--- a/src/distilabel/pipeline/base.py
+++ b/src/distilabel/pipeline/base.py
@@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import hashlib
import logging
import os
+import shutil
import signal
import threading
import time
-import uuid
from abc import ABC, abstractmethod
+from inspect import isclass
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -35,19 +35,13 @@
)
import fsspec
+from pydantic import BaseModel
from typing_extensions import Self
from upath import UPath
-from distilabel import __version__
-from distilabel.constants import (
- CONVERGENCE_STEP_ATTR_NAME,
- INPUT_QUEUE_ATTR_NAME,
- LAST_BATCH_SENT_FLAG,
- RECEIVES_ROUTED_BATCHES_ATTR_NAME,
- ROUTING_BATCH_FUNCTION_ATTR_NAME,
- STEP_ATTR_NAME,
-)
+from distilabel import __version__, constants, envs
from distilabel.distiset import create_distiset
+from distilabel.errors import DistilabelUserError
from distilabel.mixins.requirements import RequirementsMixin
from distilabel.pipeline._dag import DAG
from distilabel.pipeline.batch import _Batch
@@ -56,16 +50,22 @@
from distilabel.steps.base import GeneratorStep
from distilabel.steps.generators.utils import make_generator_step
from distilabel.utils.logging import setup_logging, stop_logging
+from distilabel.utils.notebook import in_notebook
from distilabel.utils.serialization import (
- TYPE_INFO_KEY,
_Serializable,
read_json,
)
+from distilabel.utils.typing_ import (
+ extract_annotation_inner_type,
+ is_type_pydantic_secret_field,
+)
if TYPE_CHECKING:
from os import PathLike
from queue import Queue
+ from pydantic import BaseModel
+
from distilabel.distiset import Distiset
from distilabel.pipeline.routing_batch_function import RoutingBatchFunction
from distilabel.pipeline.typing import (
@@ -87,15 +87,13 @@ class _CacheLocation(TypedDict):
pipeline: Path
batch_manager: Path
+ steps_data: Path
data: Path
batch_input_data: Path
log_file: Path
stages_file: Path
-BASE_CACHE_DIR = Path.home() / ".cache" / "distilabel" / "pipelines"
-
-
class _GlobalPipelineManager:
"""Class to manage the global pipeline instance that will be used by the steps when
created within a pipeline context.
@@ -128,6 +126,8 @@ def get_pipeline(cls) -> Union["BasePipeline", None]:
_STEP_LOAD_FAILED_CODE = -666
_STEP_NOT_LOADED_CODE = -999
+_PIPELINE_DEFAULT_NAME = "__default_pipeline_name__"
+
class BasePipeline(ABC, RequirementsMixin, _Serializable):
"""Base class for a `distilabel` pipeline.
@@ -185,17 +185,17 @@ def __init__(
Defaults to `None`, but can be helpful to inform in a pipeline to be shared
that this requirements must be installed.
"""
- self.name = name or f"pipeline_{str(uuid.uuid4())[:8]}"
+ self.name = name or _PIPELINE_DEFAULT_NAME
self.description = description
self._enable_metadata = enable_metadata
self.dag = DAG()
if cache_dir:
self._cache_dir = Path(cache_dir)
- elif env_cache_dir := os.getenv("DISTILABEL_CACHE_DIR"):
+ elif env_cache_dir := envs.DISTILABEL_CACHE_DIR:
self._cache_dir = Path(env_cache_dir)
else:
- self._cache_dir = BASE_CACHE_DIR
+ self._cache_dir = constants.PIPELINES_CACHE_DIR
self._logger = logging.getLogger("distilabel.pipeline")
@@ -209,6 +209,10 @@ def __init__(
self._stop_called_lock = threading.Lock()
self._stop_calls = 0
+ self._recover_offline_batch_generate_for_step: Union[
+ Tuple[str, List[List[Dict[str, Any]]]], None
+ ] = None
+
self._fs: Optional[fsspec.AbstractFileSystem] = None
self._storage_base_path: Optional[str] = None
self._use_fs_to_pass_data: bool = False
@@ -231,49 +235,25 @@ def __enter__(self) -> Self:
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""Unset the global pipeline instance when exiting a pipeline context."""
_GlobalPipelineManager.set_pipeline(None)
+ self._set_pipeline_name()
+
+ def _set_pipeline_name(self) -> None:
+ """Creates a name for the pipeline if it's the default one (if hasn't been set)."""
+ if self.name == _PIPELINE_DEFAULT_NAME:
+ self.name = f"pipeline_{'_'.join(self.dag)}"
- def _create_signature(self) -> str:
+ @property
+ def signature(self) -> str:
"""Makes a signature (hash) of a pipeline, using the step ids and the adjacency between them.
The main use is to find the pipeline in the cache folder.
Returns:
- int: Signature of the pipeline.
+ Signature of the pipeline.
"""
- hasher = hashlib.sha1()
- steps_info = []
pipeline_dump = self.dump()["pipeline"]
-
- for step in pipeline_dump["steps"]:
- step_info = step["name"]
- for argument, value in sorted(step[STEP_ATTR_NAME].items()):
- if (argument == TYPE_INFO_KEY) or (value is None):
- continue
-
- if isinstance(value, dict):
- # input_mappings/output_mappings
- step_info += "-".join(
- [
- f"{str(k)}={str(v)}"
- for k, v in value.items()
- if k not in ("disable_cuda_device_placement",)
- ]
- )
- elif isinstance(value, (list, tuple)):
- # runtime_parameters_info
- step_info += "-".join([str(v) for v in value])
- elif isinstance(value, (int, str, float, bool)):
- if argument != "disable_cuda_device_placement":
- # batch_size/name
- step_info += str(value)
- else:
- raise ValueError(
- f"Field '{argument}' in step '{step['name']}' has type {type(value)}, explicitly cast the type to 'str'."
- )
-
- steps_info.append(step_info)
-
+ steps_names = list(self.dag)
connections_info = [
f"{c['from']}-{'-'.join(c['to'])}" for c in pipeline_dump["connections"]
]
@@ -282,18 +262,17 @@ def _create_signature(self) -> str:
for function in pipeline_dump["routing_batch_functions"]:
step = function["step"]
routing_batch_function: "RoutingBatchFunction" = self.dag.get_step(step)[
- ROUTING_BATCH_FUNCTION_ATTR_NAME
+ constants.ROUTING_BATCH_FUNCTION_ATTR_NAME
]
if type_info := routing_batch_function._get_type_info():
step += f"-{type_info}"
+ routing_batch_functions_info.append(step)
- hasher.update(
+ return hashlib.sha1(
",".join(
- steps_info + connections_info + routing_batch_functions_info
+ steps_names + connections_info + routing_batch_functions_info
).encode()
- )
-
- return hasher.hexdigest()
+ ).hexdigest()
def run(
self,
@@ -302,6 +281,7 @@ def run(
storage_parameters: Optional[Dict[str, Any]] = None,
use_fs_to_pass_data: bool = False,
dataset: Optional["InputDataset"] = None,
+ logging_handlers: Optional[List[logging.Handler]] = None,
) -> "Distiset": # type: ignore
"""Run the pipeline. It will set the runtime parameters for the steps and validate
the pipeline.
@@ -328,6 +308,9 @@ def run(
dataset: If given, it will be used to create a `GeneratorStep` and put it as the
root step. Convenient method when you have already processed the dataset in
your script and just want to pass it already processed. Defaults to `None`.
+ logging_handlers: A list of logging handlers that will be used to log the
+ output of the pipeline. This argument can be useful so the logging messages
+ can be extracted and used in a different context. Defaults to `None`.
Returns:
The `Distiset` created by the pipeline.
@@ -340,17 +323,30 @@ def run(
# cache when the pipeline is run, so it's important to do it first.
self._set_runtime_parameters(parameters or {})
- setup_logging(
- log_queue=self._log_queue, filename=str(self._cache_location["log_file"])
- )
+ self._refresh_pipeline_from_cache()
if dataset is not None:
self._add_dataset_generator_step(dataset)
+ setup_logging(
+ log_queue=self._log_queue,
+ filename=str(self._cache_location["log_file"]),
+ logging_handlers=logging_handlers,
+ )
+
+ # Set the name of the pipeline if it's the default one. This should be called
+ # if the pipeline is defined within the context manager, and the run is called
+ # outside of it. Is here in the following case:
+ # with Pipeline() as pipeline:
+ # pipeline.run()
+ self._set_pipeline_name()
+
# Validate the pipeline DAG to check that all the steps are chainable, there are
# no missing runtime parameters, batch sizes are correct, etc.
self.dag.validate()
+ self._set_pipeline_artifacts_path_in_steps()
+
# Set the initial load status for all the steps
self._init_steps_load_status()
@@ -360,12 +356,8 @@ def run(
# Load the `_BatchManager` from cache or create one from scratch
self._load_batch_manager(use_cache)
- if to_install := self.requirements_to_install():
- # Print the list of requirements like they would appear in a requirements.txt
- to_install_list = "\n" + "\n".join(to_install)
- msg = f"Please install the following requirements to run the pipeline: {to_install_list}"
- self._logger.error(msg)
- raise ModuleNotFoundError(msg)
+ # Check pipeline requirements are installed
+ self._check_requirements()
# Setup the filesystem that will be used to pass the data of the `_Batch`es
self._setup_fsspec(storage_parameters)
@@ -383,7 +375,7 @@ def run(
" Returning `Distiset` from cache data..."
)
distiset = create_distiset(
- self._cache_location["data"],
+ data_dir=self._cache_location["data"],
pipeline_path=self._cache_location["pipeline"],
log_filename_path=self._cache_location["log_file"],
enable_metadata=self._enable_metadata,
@@ -392,7 +384,7 @@ def run(
stop_logging()
return distiset
- self._setup_write_buffer()
+ self._setup_write_buffer(use_cache)
self._print_load_stages_info()
@@ -400,6 +392,7 @@ def dry_run(
self,
parameters: Optional[Dict[str, Dict[str, Any]]] = None,
batch_size: int = 1,
+ dataset: Optional["InputDataset"] = None,
) -> "Distiset":
"""Do a dry run to test the pipeline runs as expected.
@@ -412,6 +405,9 @@ def dry_run(
the runtime parameters for the step as the value. Defaults to `None`.
batch_size: The batch size of the unique batch generated by the generators
steps of the pipeline. Defaults to `1`.
+ dataset: If given, it will be used to create a `GeneratorStep` and put it as the
+ root step. Convenient method when you have already processed the dataset in
+ your script and just want to pass it already processed. Defaults to `None`.
Returns:
Will return the `Distiset` as the main run method would do.
@@ -419,14 +415,14 @@ def dry_run(
self._dry_run = True
for step_name in self.dag:
- step = self.dag.get_step(step_name)[STEP_ATTR_NAME]
+ step = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME]
if step.is_generator:
if not parameters:
parameters = {}
parameters[step_name] = {"batch_size": batch_size}
- distiset = self.run(parameters=parameters, use_cache=False)
+ distiset = self.run(parameters=parameters, use_cache=False, dataset=dataset)
self._dry_run = False
return distiset
@@ -443,13 +439,15 @@ def _add_dataset_generator_step(self, dataset: "InputDataset") -> None:
ValueError: If there's already a `GeneratorStep` in the pipeline.
"""
for step_name in self.dag:
- step = self.dag.get_step(step_name)[STEP_ATTR_NAME]
+ step = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME]
if isinstance(step_name, GeneratorStep):
- raise ValueError(
- "There is already a `GeneratorStep` in the pipeline, you can either pass a `dataset` to the "
- f"run method, or create a `GeneratorStep` explictly. `GeneratorStep`: {step}"
+ raise DistilabelUserError(
+ "There is already a `GeneratorStep` in the pipeline, you can either"
+ " pass a `dataset` to the run method, or create a `GeneratorStep` explictly."
+ f" `GeneratorStep`: {step}",
+ page="sections/how_to_guides/basic/step/#types-of-steps",
)
- loader = make_generator_step(dataset)
+ loader = make_generator_step(dataset, self)
self.dag.add_root_step(loader)
def get_runtime_parameters_info(self) -> "PipelineRuntimeParametersInfo":
@@ -461,7 +459,7 @@ def get_runtime_parameters_info(self) -> "PipelineRuntimeParametersInfo":
"""
runtime_parameters = {}
for step_name in self.dag:
- step: "_Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME]
+ step: "_Step" = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME]
runtime_parameters[step_name] = step.get_runtime_parameters_info()
return runtime_parameters
@@ -471,6 +469,27 @@ def _init_steps_load_status(self) -> None:
for step_name in self.dag:
self._steps_load_status[step_name] = _STEP_NOT_LOADED_CODE
+ def _set_pipeline_artifacts_path_in_steps(self) -> None:
+ """Sets the attribute `_pipeline_artifacts_path` in all the `Step`s of the pipeline,
+ so steps can use it to get the path to save the generated artifacts."""
+ artifacts_path = self._cache_location["data"] / constants.STEPS_ARTIFACTS_PATH
+ for name in self.dag:
+ step: "_Step" = self.dag.get_step(name)[constants.STEP_ATTR_NAME]
+ step.set_pipeline_artifacts_path(path=artifacts_path)
+
+ def _check_requirements(self) -> None:
+ """Checks if the dependencies required to run the pipeline are installed.
+
+ Raises:
+ ModuleNotFoundError: if one or more requirements are missing.
+ """
+ if to_install := self.requirements_to_install():
+ # Print the list of requirements like they would appear in a requirements.txt
+ to_install_list = "\n" + "\n".join(to_install)
+ msg = f"Please install the following requirements to run the pipeline: {to_install_list}"
+ self._logger.error(msg)
+ raise ModuleNotFoundError(msg)
+
def _setup_fsspec(
self, storage_parameters: Optional[Dict[str, Any]] = None
) -> None:
@@ -494,9 +513,10 @@ def _setup_fsspec(
return
if "path" not in storage_parameters:
- raise ValueError(
+ raise DistilabelUserError(
"The 'path' key must be present in the `storage_parameters` dictionary"
- " if it's not `None`."
+ " if it's not `None`.",
+ page="sections/how_to_guides/advanced/fs_to_pass_data/",
)
path = storage_parameters.pop("path")
@@ -525,10 +545,12 @@ def _add_edge(self, from_step: str, to_step: str) -> None:
# Check if `from_step` has a `routing_batch_function`. If it does, then mark
# `to_step` as a step that will receive a routed batch.
node = self.dag.get_step(from_step) # type: ignore
- routing_batch_function = node.get(ROUTING_BATCH_FUNCTION_ATTR_NAME, None)
+ routing_batch_function = node.get(
+ constants.ROUTING_BATCH_FUNCTION_ATTR_NAME, None
+ )
self.dag.set_step_attr(
name=to_step,
- attr=RECEIVES_ROUTED_BATCHES_ATTR_NAME,
+ attr=constants.RECEIVES_ROUTED_BATCHES_ATTR_NAME,
value=routing_batch_function is not None,
)
@@ -538,7 +560,7 @@ def _is_convergence_step(self, step_name: str) -> None:
Args:
step_name: The name of the step.
"""
- return self.dag.get_step(step_name).get(CONVERGENCE_STEP_ATTR_NAME)
+ return self.dag.get_step(step_name).get(constants.CONVERGENCE_STEP_ATTR_NAME)
def _add_routing_batch_function(
self, step_name: str, routing_batch_function: "RoutingBatchFunction"
@@ -551,7 +573,7 @@ def _add_routing_batch_function(
"""
self.dag.set_step_attr(
name=step_name,
- attr=ROUTING_BATCH_FUNCTION_ATTR_NAME,
+ attr=constants.ROUTING_BATCH_FUNCTION_ATTR_NAME,
value=routing_batch_function,
)
@@ -570,7 +592,7 @@ def _set_runtime_parameters(self, parameters: Dict[str, Dict[str, Any]]) -> None
f" Available steps are: {step_names}."
)
else:
- step: "_Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME]
+ step: "_Step" = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME]
step.set_runtime_parameters(step_parameters)
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
@@ -585,6 +607,45 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"""
return self.dag.dump()
+ def draw(
+ self,
+ path: Optional[Union[str, Path]] = "pipeline.png",
+ top_to_bottom: bool = False,
+ show_edge_labels: bool = True,
+ ) -> str:
+ """
+ Draws the pipeline.
+
+ Parameters:
+ path: The path to save the image to.
+ top_to_bottom: Whether to draw the DAG top to bottom. Defaults to `False`.
+ show_edge_labels: Whether to show the edge labels. Defaults to `True`.
+
+ Returns:
+ The path to the saved image.
+ """
+ png = self.dag.draw(
+ top_to_bottom=top_to_bottom, show_edge_labels=show_edge_labels
+ )
+ with open(path, "wb") as f:
+ f.write(png)
+ return path
+
+ def __repr__(self) -> str:
+ """
+ If running in a Jupyter notebook, display an image representing this `Pipeline`.
+ """
+ if in_notebook():
+ try:
+ from IPython.display import Image, display
+
+ image_data = self.dag.draw()
+
+ display(Image(image_data))
+ except Exception:
+ pass
+ return super().__repr__()
+
def dump(self, **kwargs: Any) -> Dict[str, Any]:
return {
"distilabel": {"version": __version__},
@@ -624,16 +685,32 @@ def _cache_location(self) -> "_CacheLocation":
Returns:
Path: Filenames where the pipeline content will be serialized.
"""
- folder = self._cache_dir / self.name / self._create_signature()
+ folder = self._cache_dir / self.name / self.signature
+ pipeline_execution_dir = folder / "executions" / self.aggregated_steps_signature
return {
- "pipeline": folder / "pipeline.yaml",
- "batch_manager": folder / "batch_manager.json",
- "data": folder / "data",
- "batch_input_data": folder / "batch_input_data",
- "log_file": folder / "pipeline.log",
- "stages_file": folder / "stages.json",
+ "pipeline": pipeline_execution_dir / "pipeline.yaml",
+ "batch_manager": pipeline_execution_dir / "batch_manager.json",
+ "steps_data": self._cache_dir / self.name / "steps_data",
+ "data": pipeline_execution_dir / "data",
+ "batch_input_data": pipeline_execution_dir / "batch_input_data",
+ "log_file": pipeline_execution_dir / "pipeline.log",
+ "stages_file": pipeline_execution_dir / "stages.json",
}
+ @property
+ def aggregated_steps_signature(self) -> str:
+ """Creates an aggregated signature using `Step`s signature that will be used for
+ the `_BatchManager`.
+
+ Returns:
+ The aggregated signature.
+ """
+ signatures = []
+ for step_name in self.dag:
+ step: "_Step" = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME]
+ signatures.append(step.signature)
+ return hashlib.sha1("".join(signatures).encode()).hexdigest()
+
def _cache(self) -> None:
"""Saves the `BasePipeline` using the `_cache_filename`."""
if self._dry_run:
@@ -645,7 +722,10 @@ def _cache(self) -> None:
)
if self._batch_manager is not None:
- self._batch_manager.cache(self._cache_location["batch_manager"])
+ self._batch_manager.cache(
+ path=self._cache_location["batch_manager"],
+ steps_data_path=self._cache_location["steps_data"],
+ )
self._save_stages_status()
@@ -675,33 +755,137 @@ def _load_stages_status(self, use_cache: bool = True) -> None:
[] for _ in range(len(self.dag.get_steps_load_stages()[0]))
]
+ def _refresh_pipeline_from_cache(self) -> None:
+ """Refresh the DAG (and its steps) from the cache file. This is useful as some
+ `Step`s can update and change their state during the pipeline execution, and this
+ method will make sure the pipeline is up-to-date with the latest changes when
+ the pipeline is reloaded from cache.
+ """
+
+ def recursively_handle_secrets_and_excluded_attributes(
+ cached_model: "BaseModel", model: "BaseModel"
+ ) -> None:
+ """Recursively handle the secrets and excluded attributes of a `BaseModel`,
+ setting the values of the cached model to the values of the model.
+
+ Args:
+ cached_model: The cached model that will be updated as it doesn't contain
+ the secrets and excluded attributes (not serialized).
+ model: The model that contains the secrets and excluded attributes because
+ it comes from pipeline instantiation.
+ """
+ for field_name, field_info in cached_model.model_fields.items():
+ if field_name in ("pipeline"):
+ continue
+
+ inner_type = extract_annotation_inner_type(field_info.annotation)
+ if is_type_pydantic_secret_field(inner_type) or field_info.exclude:
+ setattr(cached_model, field_name, getattr(model, field_name))
+ elif isclass(inner_type) and issubclass(inner_type, BaseModel):
+ recursively_handle_secrets_and_excluded_attributes(
+ getattr(cached_model, field_name),
+ getattr(model, field_name),
+ )
+
+ if self._cache_location["pipeline"].exists():
+ cached_dag = self.from_yaml(self._cache_location["pipeline"]).dag
+
+ for step_name in cached_dag:
+ step_cached: "_Step" = cached_dag.get_step(step_name)[
+ constants.STEP_ATTR_NAME
+ ]
+ step: "_Step" = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME]
+ recursively_handle_secrets_and_excluded_attributes(step_cached, step)
+
+ self.dag = cached_dag
+
def _load_batch_manager(self, use_cache: bool = True) -> None:
"""Will try to load the `_BatchManager` from the cache dir if found. Otherwise,
it will create one from scratch.
+
+ If the `_BatchManager` is loaded from cache, we check for invalid steps (those that
+ may have a different signature than the original in the pipeline folder), and
+ restart them, as well as their successors.
+
+ Args:
+ use_cache: whether the cache should be used or not.
"""
batch_manager_cache_loc = self._cache_location["batch_manager"]
+
+ # This first condition handles the case in which the pipeline is exactly the same
+ # no steps have been added, removed or changed.
if use_cache and batch_manager_cache_loc.exists():
self._logger.info(
f"💾 Loading `_BatchManager` from cache: '{batch_manager_cache_loc}'"
)
- self._batch_manager = _BatchManager.load_from_cache(batch_manager_cache_loc)
+ self._batch_manager = _BatchManager.load_from_cache(
+ dag=self.dag,
+ batch_manager_path=batch_manager_cache_loc,
+ steps_data_path=self._cache_location["steps_data"],
+ )
+ self._invalidate_steps_cache_if_required()
+ # In this other case, the pipeline has been changed. We need to create a new batch
+ # manager and if `use_cache==True` then check which outputs have we computed and
+ # cached for steps that haven't changed but that were executed in another pipeline,
+ # and therefore we can reuse
else:
- self._batch_manager = _BatchManager.from_dag(self.dag)
+ self._batch_manager = _BatchManager.from_dag(
+ dag=self.dag,
+ use_cache=use_cache,
+ steps_data_path=self._cache_location["steps_data"],
+ )
+
+ def _invalidate_steps_cache_if_required(self) -> None:
+ """Iterates over the steps of the pipeline and invalidates their cache if required."""
+ for step_name in self.dag:
+ # `GeneratorStep`s doesn't receive input data so no need to check their
+ # `_BatchManagerStep`
+ if self.dag.get_step(step_name)[constants.STEP_ATTR_NAME].is_generator:
+ continue
+
+ step: "_Step" = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME]
+ if not step.use_cache:
+ self._batch_manager.invalidate_cache_for(
+ step_name=step.name,
+ dag=self.dag,
+ steps_data_path=self._cache_location["steps_data"],
+ ) # type: ignore
+ self._logger.info(
+ f"♻️ Step '{step.name}' won't use cache (`use_cache=False`). The cache of this step and their successors won't be"
+ " reused and the results will have to be recomputed."
+ )
+ break
- def _setup_write_buffer(self) -> None:
+ def _setup_write_buffer(self, use_cache: bool = True) -> None:
"""Setups the `_WriteBuffer` that will store the data of the leaf steps of the
pipeline while running, so the `Distiset` can be created at the end.
"""
- buffer_data_path = self._cache_location["data"]
+ if not use_cache and self._cache_location["data"].exists():
+ shutil.rmtree(self._cache_location["data"])
+ buffer_data_path = self._cache_location["data"] / constants.STEPS_OUTPUTS_PATH
self._logger.info(f"📝 Pipeline data will be written to '{buffer_data_path}'")
- self._write_buffer = _WriteBuffer(buffer_data_path, self.dag.leaf_steps)
+ self._write_buffer = _WriteBuffer(
+ buffer_data_path,
+ self.dag.leaf_steps,
+ steps_cached={
+ step_name: self.dag.get_step(step_name)[
+ constants.STEP_ATTR_NAME
+ ].use_cache
+ for step_name in self.dag
+ },
+ )
def _print_load_stages_info(self) -> None:
"""Prints the information about the load stages."""
stages, _ = self.dag.get_steps_load_stages()
msg = ""
for stage, steps in enumerate(stages):
- msg += f"\n * Stage {stage}: {steps}"
+ steps_to_be_loaded = self._steps_to_be_loaded_in_stage(stage)
+ msg += f"\n * Stage {stage}:"
+ for step in steps:
+ msg += f"\n - '{step}'"
+ if step not in steps_to_be_loaded:
+ msg += " (results cached, won't be loaded and executed)"
self._logger.info(
f"⌛ The steps of the pipeline will be loaded in stages:{msg}"
)
@@ -737,9 +921,10 @@ def _output_queue_loop(self) -> None:
# we need to handle the stop of the pipeline and break the loop to avoid
# propagating the batches through the pipeline and making the stop process
# slower.
- if self._stop_called:
- self._handle_batch_on_stop(batch)
- break
+ with self._stop_called_lock:
+ if self._stop_called:
+ self._handle_batch_on_stop(batch)
+ break
# If there is another load stage and all the `last_batch`es from the stage
# have been received, then load the next stage.
@@ -776,9 +961,12 @@ def _should_continue_processing(self) -> bool:
`True` if should continue consuming batches, `False` otherwise and the pipeline
should stop.
"""
- return self._batch_manager.can_generate() and not self._stop_called # type: ignore
+ with self._stop_called_lock:
+ return self._batch_manager.can_generate() and not self._stop_called # type: ignore
- def _process_batch(self, batch: "_Batch") -> None:
+ def _process_batch(
+ self, batch: "_Batch", send_last_batch_flag: bool = True
+ ) -> None:
"""Process a batch consumed from the `output_queue`.
Args:
@@ -794,18 +982,65 @@ def _process_batch(self, batch: "_Batch") -> None:
self._write_buffer.add_batch(batch) # type: ignore
if batch.last_batch:
- _, stages_last_steps = self.dag.get_steps_load_stages()
- stage_last_steps = stages_last_steps[self._current_stage]
- if batch.step_name in stage_last_steps:
- self._stages_last_batch[self._current_stage].append(batch.step_name)
- self._stages_last_batch[self._current_stage].sort()
+ self._register_stages_last_batch(batch)
# Make sure to send the `LAST_BATCH_SENT_FLAG` to the predecessors of the step
# if the batch is the last one, so they stop their processing loop even if they
# haven't received the last batch because of the routing function.
- for step_name in self.dag.get_step_predecessors(batch.step_name):
- if self._is_step_running(step_name):
- self._send_last_batch_flag_to_step(step_name)
+ if send_last_batch_flag:
+ for step_name in self.dag.get_step_predecessors(batch.step_name):
+ if self._is_step_running(step_name):
+ self._send_last_batch_flag_to_step(step_name)
+
+ def _set_step_for_recovering_offline_batch_generation(
+ self, step: "_Step", data: List[List[Dict[str, Any]]]
+ ) -> None:
+ """Sets the required information to recover a pipeline execution from a `_Step`
+ that used an `LLM` with offline batch generation.
+
+ Args:
+ step: The `_Step` that used an `LLM` with offline batch generation.
+ data: The data that was used to generate the batches for the step.
+ """
+ # Replace step so the attribute `jobs_ids` of the `LLM` is not lost, as it was
+ # updated in the child process but not in the main process.
+ step_name: str = step.name # type: ignore
+ self.dag.set_step_attr(
+ name=step_name, attr=constants.STEP_ATTR_NAME, value=step
+ )
+ self._recover_offline_batch_generate_for_step = (step_name, data)
+
+ def _add_batch_for_recovering_offline_batch_generation(self) -> None:
+ """Adds a dummy `_Batch` to the specified step name (it's a `Task` that used an
+ `LLM` with offline batch generation) to recover the pipeline state for offline
+ batch generation in next pipeline executions."""
+ assert self._batch_manager, "Batch manager is not set"
+
+ if self._recover_offline_batch_generate_for_step is None:
+ return
+
+ step_name, data = self._recover_offline_batch_generate_for_step
+ self._logger.debug(
+ f"Adding batch to '{step_name}' step to recover pipeline execution for offline"
+ " batch generation..."
+ )
+ self._batch_manager.add_batch_to_recover_offline_batch_generation(
+ to_step=step_name,
+ data=data,
+ )
+
+ def _register_stages_last_batch(self, batch: "_Batch") -> None:
+ """Registers the last batch received from a step in the `_stages_last_batch`
+ dictionary.
+
+ Args:
+ batch: The last batch received from a step.
+ """
+ _, stages_last_steps = self.dag.get_steps_load_stages()
+ stage_last_steps = stages_last_steps[self._current_stage]
+ if batch.step_name in stage_last_steps:
+ self._stages_last_batch[self._current_stage].append(batch.step_name)
+ self._stages_last_batch[self._current_stage].sort()
def _update_stage(self) -> bool:
"""Checks if the steps of next stage should be loaded and updates `_current_stage`
@@ -838,16 +1073,25 @@ def _should_load_next_stage(self) -> bool:
def _finalize_pipeline_execution(self) -> None:
"""Finalizes the pipeline execution handling the prematurely stop of the pipeline
if required, caching the data and ensuring that all the steps finish its execution."""
- if self._stop_called:
- self._handle_stop()
-
- self._cache()
# Send `None` to steps `input_queue`s just in case some step is still waiting
self._notify_steps_to_stop()
- # Reset flag state
- self._stop_called = False
+ for step_name in self.dag:
+ while self._is_step_running(step_name):
+ self._logger.debug(f"Waiting for step '{step_name}' to finish...")
+ time.sleep(0.5)
+
+ with self._stop_called_lock:
+ if self._stop_called:
+ self._handle_stop()
+
+ # Reset flag state
+ self._stop_called = False
+
+ self._add_batch_for_recovering_offline_batch_generation()
+
+ self._cache()
def _run_load_queue_loop_in_thread(self) -> threading.Thread:
"""Runs a background thread that reads from the `load_queue` to update the status
@@ -898,6 +1142,26 @@ def _is_step_running(self, step_name: str) -> bool:
with self._steps_load_status_lock:
return self._steps_load_status[step_name] >= 1
+ def _steps_to_be_loaded_in_stage(self, stage: int) -> List[str]:
+ """Returns the list of steps of the provided stage that should be loaded taking
+ into account if they have finished.
+
+ Args:
+ stage: the stage number
+
+ Returns:
+ A list containing the name of the steps that should be loaded in this stage.
+ """
+ assert self._batch_manager, "Batch manager is not set"
+
+ steps_stages, _ = self.dag.get_steps_load_stages()
+
+ return [
+ step
+ for step in steps_stages[stage]
+ if not self._batch_manager.step_has_finished(step)
+ ]
+
def _run_stage_steps_and_wait(self, stage: int) -> bool:
"""Runs the steps of the specified stage and waits for them to be ready.
@@ -907,9 +1171,10 @@ def _run_stage_steps_and_wait(self, stage: int) -> bool:
Returns:
`True` if all the steps have been loaded correctly, `False` otherwise.
"""
+ assert self._batch_manager, "Batch manager is not set"
- steps_stages, _ = self.dag.get_steps_load_stages()
- steps = steps_stages[stage]
+ steps = self._steps_to_be_loaded_in_stage(stage)
+ self._logger.debug(f"Steps to be loaded in stage {stage}: {steps}")
# Run the steps of the stage
self._run_steps(steps=steps)
@@ -917,46 +1182,47 @@ def _run_stage_steps_and_wait(self, stage: int) -> bool:
# Wait for them to be ready
self._logger.info(f"⏳ Waiting for all the steps of stage {stage} to load...")
previous_message = None
- while not self._stop_called:
- with self._steps_load_status_lock:
- filtered_steps_load_status = {
- step_name: replicas
- for step_name, replicas in self._steps_load_status.items()
- if step_name in steps
- }
- self._logger.debug(
- f"Steps from stage {stage} loaded: {filtered_steps_load_status}"
- )
-
- if any(
- replicas_loaded == _STEP_LOAD_FAILED_CODE
- for replicas_loaded in filtered_steps_load_status.values()
- ):
- self._logger.error(
- f"❌ Failed to load all the steps of stage {stage}"
- )
- return False
-
- num_steps_loaded = 0
- replicas_message = ""
- for step_name, replicas in filtered_steps_load_status.items():
- step_replica_count = self.dag.get_step_replica_count(step_name)
- if replicas == step_replica_count:
- num_steps_loaded += 1
- replicas_message += f"\n * '{step_name}' replicas: {max(0, replicas)}/{step_replica_count}"
-
- message = f"⏳ Steps from stage {stage} loaded: {num_steps_loaded}/{len(filtered_steps_load_status)}{replicas_message}"
- if num_steps_loaded > 0 and message != previous_message:
- self._logger.info(message)
- previous_message = message
-
- if num_steps_loaded == len(filtered_steps_load_status):
- self._logger.info(
- f"✅ All the steps from stage {stage} have been loaded!"
+ with self._stop_called_lock:
+ while not self._stop_called:
+ with self._steps_load_status_lock:
+ filtered_steps_load_status = {
+ step_name: replicas
+ for step_name, replicas in self._steps_load_status.items()
+ if step_name in steps
+ }
+ self._logger.debug(
+ f"Steps from stage {stage} loaded: {filtered_steps_load_status}"
)
- return True
- time.sleep(2.5)
+ if any(
+ replicas_loaded == _STEP_LOAD_FAILED_CODE
+ for replicas_loaded in filtered_steps_load_status.values()
+ ):
+ self._logger.error(
+ f"❌ Failed to load all the steps of stage {stage}"
+ )
+ return False
+
+ num_steps_loaded = 0
+ replicas_message = ""
+ for step_name, replicas in filtered_steps_load_status.items():
+ step_replica_count = self.dag.get_step_replica_count(step_name)
+ if replicas == step_replica_count:
+ num_steps_loaded += 1
+ replicas_message += f"\n * '{step_name}' replicas: {max(0, replicas)}/{step_replica_count}"
+
+ message = f"⏳ Steps from stage {stage} loaded: {num_steps_loaded}/{len(filtered_steps_load_status)}{replicas_message}"
+ if num_steps_loaded > 0 and message != previous_message:
+ self._logger.info(message)
+ previous_message = message
+
+ if num_steps_loaded == len(filtered_steps_load_status):
+ self._logger.info(
+ f"✅ All the steps from stage {stage} have been loaded!"
+ )
+ return True
+
+ time.sleep(2.5)
return not self._stop_called
@@ -975,6 +1241,9 @@ def _handle_stop(self) -> None:
self._consume_output_queue()
+ if self._should_load_next_stage():
+ self._current_stage += 1
+
def _wait_step_input_queue_empty(self, step_name: str) -> Union["Queue[Any]", None]:
"""Waits for the input queue of a step to be empty.
@@ -987,7 +1256,9 @@ def _wait_step_input_queue_empty(self, step_name: str) -> Union["Queue[Any]", No
if self._check_step_not_loaded_or_finished(step_name):
return None
- if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME):
+ if input_queue := self.dag.get_step(step_name).get(
+ constants.INPUT_QUEUE_ATTR_NAME
+ ):
while input_queue.qsize() != 0:
pass
return input_queue
@@ -1026,7 +1297,7 @@ def _create_step_input_queue(self, step_name: str) -> "Queue[Any]":
The input queue created.
"""
input_queue = self.QueueClass()
- self.dag.set_step_attr(step_name, INPUT_QUEUE_ATTR_NAME, input_queue)
+ self.dag.set_step_attr(step_name, constants.INPUT_QUEUE_ATTR_NAME, input_queue)
return input_queue
@abstractmethod
@@ -1048,7 +1319,7 @@ def _run_steps(self, steps: Iterable[str]) -> None:
steps:
"""
for step_name in steps:
- step: "Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME]
+ step: "Step" = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME]
input_queue = self._create_step_input_queue(step_name=step_name)
# Set `pipeline` to `None` as in some Python environments the pipeline is not
@@ -1065,24 +1336,32 @@ def _run_steps(self, steps: Iterable[str]) -> None:
step_num_replicas: int = step.resources.replicas if step.is_normal else 1 # type: ignore
for replica in range(step_num_replicas):
- self._logger.debug(f"Running 1 replica of step '{step.name}'...")
- self._run_step(step=step, input_queue=input_queue, replica=replica)
+ self._logger.debug(
+ f"Running 1 replica of step '{step.name}' with ID {replica}..."
+ )
+ self._run_step(
+ step=step.model_copy(deep=True),
+ input_queue=input_queue,
+ replica=replica,
+ )
def _add_batches_back_to_batch_manager(self) -> None:
"""Add the `Batch`es that were sent to a `Step` back to the `_BatchManager`. This
method should be used when the pipeline has been stopped prematurely."""
for step_name in self.dag:
node = self.dag.get_step(step_name)
- step: "_Step" = node[STEP_ATTR_NAME]
+ step: "_Step" = node[constants.STEP_ATTR_NAME]
if step.is_generator:
continue
- if input_queue := node.get(INPUT_QUEUE_ATTR_NAME):
+ if input_queue := node.get(constants.INPUT_QUEUE_ATTR_NAME):
while not input_queue.empty():
batch = input_queue.get()
if not isinstance(batch, _Batch):
continue
self._batch_manager.add_batch( # type: ignore
- to_step=step_name, batch=batch, prepend=True
+ to_step=step_name,
+ batch=batch,
+ prepend=True,
)
self._logger.debug(
f"Adding batch back to the batch manager: {batch}"
@@ -1097,10 +1376,7 @@ def _consume_output_queue(self) -> None:
batch = self._output_queue.get()
if batch is None:
continue
-
- if batch.step_name in self.dag.leaf_steps:
- self._write_buffer.add_batch(batch) # type: ignore
-
+ self._process_batch(batch, send_last_batch_flag=False)
self._handle_batch_on_stop(batch)
def _manage_batch_flow(self, batch: "_Batch") -> None:
@@ -1114,10 +1390,10 @@ def _manage_batch_flow(self, batch: "_Batch") -> None:
"""
assert self._batch_manager, "Batch manager is not set"
- self._register_batch(batch)
-
route_to, do_not_route_to, routed = self._get_successors(batch)
+ self._register_batch(batch)
+
# Keep track of the steps that the batch was routed to
if routed:
batch.batch_routed_to = route_to
@@ -1149,14 +1425,16 @@ def _manage_batch_flow(self, batch: "_Batch") -> None:
# If successor step has enough data in its buffer to create a new batch, then
# send the batch to the step.
- if new_batch := self._batch_manager.get_batch(successor):
+ while new_batch := self._batch_manager.get_batch(successor):
self._send_batch_to_step(new_batch)
if not step.is_generator:
# Step ("this", the one from which the batch was received) has enough data on its
# buffers to create a new batch
- if new_batch := self._batch_manager.get_batch(step.name): # type: ignore
+ while new_batch := self._batch_manager.get_batch(step.name): # type: ignore
+ # if new_batch := self._batch_manager.get_batch(step.name): # type: ignore
self._send_batch_to_step(new_batch)
+
else:
self._request_more_batches_if_needed(step)
else:
@@ -1172,7 +1450,7 @@ def _send_to_step(self, step_name: str, to_send: Any) -> None:
step_name: The name of the step.
to_send: The object to send.
"""
- input_queue = self.dag.get_step(step_name)[INPUT_QUEUE_ATTR_NAME]
+ input_queue = self.dag.get_step(step_name)[constants.INPUT_QUEUE_ATTR_NAME]
input_queue.put(to_send)
def _send_batch_to_step(self, batch: "_Batch") -> None:
@@ -1191,7 +1469,7 @@ def _send_batch_to_step(self, batch: "_Batch") -> None:
)
self._batch_manager.set_last_batch_sent(batch) # type: ignore
- step: "_Step" = self.dag.get_step(batch.step_name)[STEP_ATTR_NAME]
+ step: "_Step" = self.dag.get_step(batch.step_name)[constants.STEP_ATTR_NAME]
if not step.is_generator and (step.is_global or self._use_fs_to_pass_data):
base_path = UPath(self._storage_base_path) / step.name # type: ignore
self._logger.debug(
@@ -1212,7 +1490,7 @@ def _gather_requirements(self) -> List[str]:
"""
steps_requirements = []
for step in self.dag:
- step_req = self.dag.get_step(step)[STEP_ATTR_NAME].requirements
+ step_req = self.dag.get_step(step)[constants.STEP_ATTR_NAME].requirements
steps_requirements.extend(step_req)
return steps_requirements
@@ -1223,7 +1501,10 @@ def _register_batch(self, batch: "_Batch") -> None:
Args:
batch: The batch to register.
"""
- self._batch_manager.register_batch(batch) # type: ignore
+ assert self._batch_manager, "Batch manager is not set"
+ self._batch_manager.register_batch(
+ batch, steps_data_path=self._cache_location["steps_data"]
+ ) # type: ignore
self._logger.debug(
f"Batch {batch.seq_no} from step '{batch.step_name}' registered in batch"
" manager"
@@ -1241,13 +1522,12 @@ def _send_last_batch_flag_to_step(self, step_name: str) -> None:
)
for _ in range(self.dag.get_step_replica_count(step_name)):
- self._send_to_step(step_name, LAST_BATCH_SENT_FLAG)
+ self._send_to_step(step_name, constants.LAST_BATCH_SENT_FLAG)
self._batch_manager.set_last_batch_flag_sent_to(step_name) # type: ignore
def _request_initial_batches(self) -> None:
"""Requests the initial batches to the generator steps."""
assert self._batch_manager, "Batch manager is not set"
-
for step in self._batch_manager._steps.values():
if not self._is_step_running(step.step_name):
continue
@@ -1307,8 +1587,10 @@ def _handle_batch_on_stop(self, batch: "_Batch") -> None:
"""
assert self._batch_manager, "Batch manager is not set"
- self._batch_manager.register_batch(batch)
- step: "Step" = self.dag.get_step(batch.step_name)[STEP_ATTR_NAME]
+ self._batch_manager.register_batch(
+ batch, steps_data_path=self._cache_location["steps_data"]
+ )
+ step: "Step" = self.dag.get_step(batch.step_name)[constants.STEP_ATTR_NAME]
for successor in self.dag.get_step_successors(step.name): # type: ignore
self._batch_manager.add_batch(successor, batch)
@@ -1321,7 +1603,7 @@ def _get_step_from_batch(self, batch: "_Batch") -> "Step":
Returns:
The `Step` instance.
"""
- return self.dag.get_step(batch.step_name)[STEP_ATTR_NAME]
+ return self.dag.get_step(batch.step_name)[constants.STEP_ATTR_NAME]
def _notify_steps_to_stop(self) -> None:
"""Notifies the steps to stop their infinite running loop by sending `None` to
@@ -1329,7 +1611,8 @@ def _notify_steps_to_stop(self) -> None:
with self._steps_load_status_lock:
for step_name, replicas in self._steps_load_status.items():
if replicas > 0:
- self._send_to_step(step_name, None)
+ for _ in range(replicas):
+ self._send_to_step(step_name, None)
def _get_successors(self, batch: "_Batch") -> Tuple[List[str], List[str], bool]:
"""Gets the successors and the successors to which the batch has to be routed.
@@ -1342,12 +1625,14 @@ def _get_successors(self, batch: "_Batch") -> Tuple[List[str], List[str], bool]:
a routing function.
"""
node = self.dag.get_step(batch.step_name)
- step: "Step" = node[STEP_ATTR_NAME]
+ step: "Step" = node[constants.STEP_ATTR_NAME]
successors = list(self.dag.get_step_successors(step.name)) # type: ignore
route_to = successors
# Check if the step has a routing function to send the batch to specific steps
- if routing_batch_function := node.get(ROUTING_BATCH_FUNCTION_ATTR_NAME):
+ if routing_batch_function := node.get(
+ constants.ROUTING_BATCH_FUNCTION_ATTR_NAME
+ ):
route_to = routing_batch_function(batch, successors)
successors_str = ", ".join(f"'{successor}'" for successor in route_to)
self._logger.info(
@@ -1424,3 +1709,10 @@ def signal_handler(signumber: int, frame: Any) -> None:
self._stop()
return signal.signal(signal.SIGINT, signal_handler)
+
+
+def set_pipeline_running_env_variables(
+ pipeline_name: str, pipeline_cache_id: str
+) -> None:
+ os.environ[constants.PIPELINE_NAME_ENV_NAME] = pipeline_name
+ os.environ[constants.PIPELINE_CACHE_ID_ENV_NAME] = pipeline_cache_id
diff --git a/src/distilabel/pipeline/batch.py b/src/distilabel/pipeline/batch.py
index d8ad4312ae..684328f53e 100644
--- a/src/distilabel/pipeline/batch.py
+++ b/src/distilabel/pipeline/batch.py
@@ -37,8 +37,11 @@ class _Batch(_Serializable):
data_hash: The hash of the data. Defaults to `None`.
data_path: The path where the data of the batch is stored. Defaults to `None`.
accumulated: A flag to indicate if the batch is accumulated.
- created_from: A dictionary containing the `seq_no` of the batches of the steps that
- were used to create this batch.
+ created_from: A dictionary containing which batches from which steps were used
+ to created this batch. The keys are the names of the steps and the values
+ are lists for each step containing the `seq_no` of each batch used, the original containing the `seq_no` of the batches of the steps that
+ size of the batch used and the number of rows used from the batch to create
+ this batch.
size: The size of the batch.
"""
@@ -49,7 +52,7 @@ class _Batch(_Serializable):
data_hash: Optional[str] = None
data_path: Optional[str] = None
accumulated: bool = False
- created_from: Dict[str, List[Tuple[int, int]]] = field(default_factory=dict)
+ created_from: Dict[str, List[Tuple[int, int, int]]] = field(default_factory=dict)
batch_routed_to: List[str] = field(default_factory=list)
size: int = 0
_fs: Optional[fsspec.AbstractFileSystem] = None
@@ -99,6 +102,7 @@ def get_data(self, num_rows: Union[int, None] = None) -> List[Dict[str, Any]]:
data = self.data[0][:num_rows]
self.data[0] = self.data[0][num_rows:]
+ # self.size = len(self.data[0])
self._update_data_hash()
return data
diff --git a/src/distilabel/pipeline/batch_manager.py b/src/distilabel/pipeline/batch_manager.py
index 42b736f7c8..9ca05e48e2 100644
--- a/src/distilabel/pipeline/batch_manager.py
+++ b/src/distilabel/pipeline/batch_manager.py
@@ -13,9 +13,10 @@
# limitations under the License.
from collections import defaultdict
+from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
from distilabel.constants import (
RECEIVES_ROUTED_BATCHES_ATTR_NAME,
@@ -70,6 +71,14 @@ class _BatchManagerStep(_Serializable):
batch from step A used by steps B and C and obtained from the `created_from`
of the batches created by them. It's used to avoid messing up the order of the
batches. Only used if `convergence_step=True`. Defaults to `0`.
+ step_signature: The signature that defines a given `Step`. It will be used for the
+ caching mechanism.
+ use_cache: Flag from the original `Step` to indicate whether this step should make use of
+ the cached data.
+ step_offset: Dictionary with each key the predecessor/s step/s and as value a dict
+ with keys `batch` and `offset`, containing the name of the file for the corresponding
+ batch, and the number of rows that were read from that step, respectively. Used
+ for caching mechanism.
"""
step_name: str
@@ -85,6 +94,9 @@ class _BatchManagerStep(_Serializable):
)
next_expected_created_from_batch_seq_no: int = 0
next_expected_seq_no: Dict[str, Tuple[int, int]] = field(default_factory=dict)
+ step_signature: Optional[str] = None
+ use_cache: bool = False
+ step_offset: Dict[str, Tuple[int, int]] = field(default_factory=dict)
def add_batch(self, batch: _Batch, prepend: bool = False) -> None:
"""Add a batch of data from `batch.step_name` to the step. It will accumulate the
@@ -124,14 +136,22 @@ def get_batch(self) -> Union[_Batch, None]:
if not self._ready_to_create_batch():
return None
+ seq_no = self._get_seq_no()
+
# `_last_batch` must be called before `_get_data`, as `_get_data` will update the
# list of data which is used to determine if the batch to be created is the last one.
- # TODO: remove `_last_batch` method and integrate logic in `_get_data`
last_batch = self._last_batch()
+
+ # Get the batch data and the information from which batches of the upstream steps
+ # the data was taken.
data, created_from, batch_routed_to = self._get_data()
+ # Update the step offset i.e. which is the last batch and last row index from that
+ # batch that the step has consumed
+ self._update_offset(created_from)
+
return _Batch(
- seq_no=self._get_seq_no(),
+ seq_no=seq_no,
step_name=self.step_name,
last_batch=last_batch,
data=data,
@@ -212,6 +232,9 @@ def from_step(
data={predecessor: [] for predecessor in predecessors},
convergence_step=convergence_step,
next_expected_seq_no={predecessor: (0, 0) for predecessor in predecessors},
+ step_signature=step.signature,
+ use_cache=step.use_cache,
+ step_offset={predecessor: (0, 0) for predecessor in predecessors},
)
def _get_seq_no(self) -> int:
@@ -226,7 +249,9 @@ def _get_seq_no(self) -> int:
def _get_data(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]:
+ ) -> Tuple[
+ List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]], List[str]
+ ]:
"""Gets the data needed to create a batch for the step to process. If the step is
accumulating data, then it will return a list with all the data received from the
predecessors. Otherwise, it will return a list of data with the `input_batch_size`
@@ -252,7 +277,7 @@ def _get_data(
def _get_data_for_accumulate(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]:
+ ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]]]:
"""Gets the data needed to create a batch for the step to process when the step
is accumulating data. It will return a list with all the data received from the
predecessors. In addition, it will remove the data used to create the batch from
@@ -268,7 +293,7 @@ def _get_data_for_accumulate(
for step_name, batches in self.data.items():
batches_used[step_name] = []
for batch in batches:
- batches_used[step_name].append((batch.seq_no, batch.size))
+ batches_used[step_name].append((batch.seq_no, batch.size, batch.size))
data.append([row for batch in batches for row in batch.get_data()])
# Reset the data buffer
self.data = {step_name: [] for step_name in self.data}
@@ -276,7 +301,7 @@ def _get_data_for_accumulate(
def _get_data_for_convergence_step(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]:
+ ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]]]:
"""Gets the data needed to create a batch for the step to process when the step is
a convergence step.
@@ -315,7 +340,7 @@ def _get_data_for_convergence_step(
remaining_rows_per_step[batch.step_name] -= num_rows # type: ignore
# Keep track of the batches used to create the batch
- batches_used[batch.step_name].append((batch.seq_no, batch.size))
+ batches_used[batch.step_name].append((batch.seq_no, batch.size, num_rows))
# If the batch was entirely consumed, then remove it from the buffer
if len(batch.data[0]) == 0:
@@ -336,7 +361,9 @@ def _get_data_for_convergence_step(
def _get_data_normal(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]:
+ ) -> Tuple[
+ List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]], List[str]
+ ]:
"""Gets the data needed to create a batch for the step to process when the step is
not accumulating data. It will return a list of data with the `input_batch_size`
for each predecessor. In addition, it will remove the data used to create the batch
@@ -374,7 +401,7 @@ def _get_data_normal(
remaining_rows -= num_rows
# Keep track of the batches used to create the batch
- batches_used[step_name].append((batch.seq_no, batch.size))
+ batches_used[step_name].append((batch.seq_no, batch.size, num_rows))
next_expected_seq_no = batch.seq_no
@@ -508,9 +535,11 @@ def _ready_to_create_batch_normal(self) -> bool:
# `batches` are sorted by `seq_no`
num_rows = 0
+ is_batch_in_order = True
for batch in batches:
# Need to create batches using the data from batches with sequential `seq_no`
if batch.seq_no != next_expected_seq_no:
+ is_batch_in_order = False
break
# There are enough rows to create a batch
num_rows += len(batch.data[0])
@@ -524,11 +553,12 @@ def _ready_to_create_batch_normal(self) -> bool:
return False
# If there are not enough rows and the last batch was not received yet, then
- # there is not enough data yet to creata a batch
+ # there is not enough data yet to create a batch
+ # If the last batch was received, the batch preceding it must be in order
if (
self.input_batch_size
and num_rows < self.input_batch_size
- and step_name not in self.last_batch_received
+ and not (step_name in self.last_batch_received and is_batch_in_order)
):
return False
@@ -549,6 +579,35 @@ def _last_batch(self) -> bool:
return self._last_batch_normal()
+ def _update_offset(
+ self, created_from: Dict[str, List[Tuple[int, int, int]]]
+ ) -> None:
+ """Update the offset for the batch buffers of the upstream steps.
+
+ Args:
+ created_from: A dictionary containing which batches from which steps were used
+ to created this batch. The keys are the names of the steps and the values
+ are lists for each step containing the `seq_no` of each batch used, the original containing the `seq_no` of the batches of the steps that
+ size of the batch used and the number of rows used from the batch to create
+ this batch.
+ """
+ for predecessor, seq_no_and_batch in created_from.items():
+ prev_last_batch_seq_no, prev_last_batch_offset = self.step_offset[
+ predecessor
+ ]
+ last_batch_seq_no, _, last_batch_size = seq_no_and_batch[-1]
+ batch_offset = (
+ prev_last_batch_offset + last_batch_size
+ if prev_last_batch_seq_no == last_batch_seq_no
+ else last_batch_size
+ )
+ last_batch_seq_no = (
+ last_batch_seq_no
+ if last_batch_seq_no > prev_last_batch_seq_no
+ else prev_last_batch_seq_no
+ )
+ self.step_offset[predecessor] = (last_batch_seq_no, batch_offset)
+
def _last_batch_accumulate(self) -> bool:
"""Checks if the batch to be created is the last one for an step accumulating data.
`True` if the last batch was received from all the predecessors.
@@ -593,11 +652,7 @@ def _last_batch_normal(self) -> bool:
num_rows = sum(len(batch.data[0]) for batch in batches)
- if (
- self.input_batch_size
- and num_rows > self.input_batch_size
- and step_name in self.last_batch_received
- ):
+ if self.input_batch_size and num_rows > self.input_batch_size:
return False
return True
@@ -616,12 +671,12 @@ def _group_batches_by_created_from(
for batches in self.data.values():
for batch in batches:
first_key = next(iter(batch.created_from))
- batch_seq_no, batch_size = batch.created_from[first_key][0]
+ batch_seq_no, batch_size, _ = batch.created_from[first_key][0]
grouped_batches[batch_seq_no].append((batch, batch_size))
return sorted((seq_no, batches) for seq_no, batches in grouped_batches.items())
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
- """Dumps the content of the `_BatchManagerStep` to a dictionary, using the `dataclass` helper function.
+ """Dumps the content of the `_BatchManagerStep` to a dictionary.
Args:
obj: Unused, just kept to match the signature of the parent method.
@@ -645,8 +700,15 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"convergence_step_batches_consumed": self.convergence_step_batches_consumed,
"next_expected_created_from_batch_seq_no": self.next_expected_created_from_batch_seq_no,
"next_expected_seq_no": self.next_expected_seq_no,
+ "step_signature": self.step_signature,
+ "use_cache": self.use_cache,
+ "step_offset": self.step_offset,
}
+ @property
+ def signature(self) -> str:
+ return f"{self.step_name}_{self.step_signature}"
+
class _BatchManager(_Serializable):
"""Class to manage the batches received from the steps. It keeps track of the
@@ -672,9 +734,9 @@ def __init__(
Args:
steps: A dictionary with the step name as the key and a dictionary with the
predecessor step name as the key and a list of batches as the value.
- last_batch_received: A dictionary with the step name as the key and a the last
+ last_batch_received: A dictionary with the step name as the key and the last
`_Batch` received from the step.
- last_batch_sent: A dictionary with the step name as the key and a the last
+ last_batch_sent: A dictionary with the step name as the key and the last
`_Batch` sent to the step.
last_batch_flag_sent_to: A list with the names of the steps to which `LAST_BATCH_SENT_FLAG`
was sent.
@@ -692,7 +754,6 @@ def can_generate(self) -> bool:
`True` if there are still batches to be processed by the steps. Otherwise,
`False`.
"""
-
for step_name, batch in self._last_batch_received.items():
if step_name not in self._last_batch_flag_sent_to:
if not batch:
@@ -706,17 +767,38 @@ def can_generate(self) -> bool:
return False
- def register_batch(self, batch: _Batch) -> None:
+ def register_batch(
+ self, batch: _Batch, steps_data_path: Optional["StrOrPath"] = None
+ ) -> None:
"""Method to register a batch received from a step. It will keep track of the
sequence number and the last batch received from the step in the internal maps.
Args:
batch: _Batch from which we will register the sequence number and the last batch received.
+ steps_data_path: The path where the outputs of each `Step` (considering its
+ signature) will be saved for later reuse in another pipelines executions.
"""
last_batch = self._last_batch_received[batch.step_name]
if not last_batch or (last_batch and last_batch.seq_no < batch.seq_no):
self._last_batch_received[batch.step_name] = batch
+ if steps_data_path:
+ self.write_batch_data(batch, steps_data_path)
+
+ def write_batch_data(self, batch: _Batch, steps_data_path: Path) -> None:
+ """Writes the batch to the steps data directory.
+
+ Argument:
+ batch: the batch to be written.
+ steps_data_path: the steps data base directory.
+ """
+ step = self._steps[batch.step_name]
+ batch_manager_data_dir = Path(steps_data_path) / step.signature
+ batch_manager_data_dir.mkdir(parents=True, exist_ok=True)
+ filename = batch_manager_data_dir / f"batch_{batch.seq_no}.json"
+ if not filename.exists():
+ self.save(path=filename, format="json", dump=batch.dump())
+
def get_last_batch(self, step_name: str) -> Union[_Batch, None]:
"""Gets the last batch received from a step.
@@ -728,7 +810,12 @@ def get_last_batch(self, step_name: str) -> Union[_Batch, None]:
"""
return self._last_batch_received.get(step_name)
- def add_batch(self, to_step: str, batch: _Batch, prepend: bool = False) -> None:
+ def add_batch(
+ self,
+ to_step: str,
+ batch: _Batch,
+ prepend: bool = False,
+ ) -> None:
"""Add an output batch from `batch.step_name` to `to_step`.
Args:
@@ -742,10 +829,27 @@ def add_batch(self, to_step: str, batch: _Batch, prepend: bool = False) -> None:
"""
if to_step not in self._steps:
raise ValueError(f"Step '{to_step}' not found in the batch manager.")
-
step = self._steps[to_step]
step.add_batch(batch, prepend)
+ def add_batch_to_recover_offline_batch_generation(
+ self, to_step: str, data: List[List[Dict[str, Any]]]
+ ) -> None:
+ """Add a batch to recover pipeline execution from an `_Step` that used an `LLM`
+ with offline batch generation. It will add the batch to the start of the buffer
+ of the step and set the last batch received of the step to `None`.
+
+ Args:
+ to_step: The name of the step that will process the batch.
+ data: The data that was used with the offline batch generation.
+ """
+ self.add_batch(
+ to_step=to_step,
+ batch=_Batch(seq_no=0, step_name=to_step, last_batch=True, data=data),
+ prepend=True,
+ )
+ self._last_batch_received[to_step] = None
+
def get_batch(self, step_name: str) -> Union[_Batch, None]:
"""Get the next batch to be processed by the step.
@@ -803,25 +907,46 @@ def set_last_batch_flag_sent_to(self, step_name: str) -> None:
def set_next_expected_seq_no(
self, step_name: str, from_step: str, next_expected_seq_no: int
) -> None:
- """Sets the next expected sequence number of a `_Batch` received by `step` comming
+ """Sets the next expected sequence number of a `_Batch` received by `step` coming
from `from_step`.
Args:
- step_name: The step name which next expected sequence number for `from_step`
+ step_name: The step name whose next expected sequence number for `from_step`
has to be updated.
from_step: The name of the step from which its next expected sequence number
in step has to be updated.
- next_expected_seq_no: the next expected sequence number of a `_Batch` comming
+ next_expected_seq_no: the next expected sequence number of a `_Batch` coming
from `from_step`.
"""
self._steps[step_name].set_next_expected_seq_no(from_step, next_expected_seq_no)
+ def step_has_finished(self, step_name: str) -> bool:
+ """Indicates if the step has finished by checking if it sent a batch with `last_batch==True`
+ or it was sent the `LAST_BATCH_SENT_FLAG`.
+
+ Args:
+ step_name: the name of the step to be checked.
+
+ Returns:
+ `True` if step has finished generating batches, `False` otherwise.
+ """
+ return step_name in self._last_batch_flag_sent_to or (
+ self._last_batch_received[step_name] is not None
+ and self._last_batch_received[step_name].last_batch # type: ignore
+ )
+
@classmethod
- def from_dag(cls, dag: "DAG") -> "_BatchManager":
+ def from_dag( # noqa: C901
+ cls, dag: "DAG", use_cache: bool = False, steps_data_path: Optional[Path] = None
+ ) -> "_BatchManager":
"""Create a `_BatchManager` instance from a `DAG` instance.
Args:
dag: The `DAG` instance.
+ use_cache: whether or not to try loading outputs from steps of previous pipelines
+ executions. Defaults to `False`.
+ steps_data_path: The path where the outputs of each `Step` (considering its
+ signature) will be saved for later reuse in another pipelines executions.
Returns:
A `_BatchManager` instance.
@@ -829,12 +954,14 @@ def from_dag(cls, dag: "DAG") -> "_BatchManager":
steps = {}
last_batch_received = {}
last_batch_sent = {}
+ last_batch_flag_sent_to = []
+
+ load_batches = {}
+ steps_to_load_data_from_previous_executions: Dict[str, Union[Path, None]] = {}
for step_name in dag:
step: "_Step" = dag.get_step(step_name)[STEP_ATTR_NAME]
last_batch_received[step.name] = None
last_batch_sent[step.name] = None
- if step.is_generator:
- continue
predecessors = list(dag.get_step_predecessors(step_name))
convergence_step = all(
dag.get_step(predecessor).get(RECEIVES_ROUTED_BATCHES_ATTR_NAME, False)
@@ -845,8 +972,55 @@ def from_dag(cls, dag: "DAG") -> "_BatchManager":
predecessors=predecessors,
convergence_step=convergence_step,
)
+
+ all_step_precessors_use_cache = all(
+ dag.get_step(step_name)[STEP_ATTR_NAME].use_cache
+ for step_name in predecessors
+ )
+ if use_cache and step.use_cache and all_step_precessors_use_cache:
+ step_data_path = steps_data_path / batch_manager_step.signature
+ if step_data_path.exists():
+ steps_to_load_data_from_previous_executions[step_name] = (
+ step_data_path
+ )
+ # We only want to load the outputs that are directly needed by the added
+ # steps, so if we need to load the outputs of one step and one of its
+ # predecessors it's in the list, then we remove it.
+ for predecessor in predecessors:
+ if predecessor in steps_to_load_data_from_previous_executions:
+ steps_to_load_data_from_previous_executions[predecessor] = (
+ None
+ )
+
steps[step_name] = batch_manager_step
- return cls(steps, last_batch_received, last_batch_sent, [])
+
+ for (
+ step_name,
+ step_outputs_path,
+ ) in steps_to_load_data_from_previous_executions.items():
+ last_batch_flag_sent_to.append(step_name)
+ if step_outputs_path is None:
+ continue
+ load_batches[step_name] = sorted(
+ [
+ _Batch.from_json(batch_file)
+ for batch_file in step_outputs_path.glob("*.json")
+ if batch_file.is_file() and batch_file.suffix == ".json"
+ ],
+ key=lambda x: x.seq_no,
+ )
+ last_batch_received[step_name] = load_batches[step_name][-1]
+
+ # Load batches from previous steps in batch manager steps
+ for step_name, batch_manager_step in steps.items():
+ for predecessor in dag.get_step_predecessors(step_name):
+ if predecessor in load_batches:
+ batch_manager_step.data[predecessor] = deepcopy(
+ load_batches[predecessor]
+ )
+ batch_manager_step.last_batch_received.append(predecessor)
+
+ return cls(steps, last_batch_received, last_batch_sent, last_batch_flag_sent_to)
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"""Dumps the content of the `_BatchManager` to a dictionary.
@@ -871,12 +1045,14 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"last_batch_flag_sent_to": self._last_batch_flag_sent_to,
}
- def cache(self, path: "StrOrPath") -> None:
+ def cache(self, path: Path, steps_data_path: Path) -> None: # noqa: C901
"""Cache the `_BatchManager` to a file.
Args:
path: The path to the file where the `_BatchManager` will be cached. If `None`,
then the `_BatchManager` will be cached in the default cache folder.
+ steps_data_path: The path where the outputs of each `Step` (considering its
+ signature) will be saved for later reuse in another pipelines executions.
"""
def save_batch(
@@ -932,26 +1108,6 @@ def remove_files(keep_files: List[str], dir: Path) -> None:
# Remove built `_Batch`es that were consumed from cache
remove_files(step_dump["built_batches"], built_batches_dir)
- # Store each `_BatchManagerStep` `_Batch`es in a separate file
- for buffered_step_name in step_dump["data"]:
- step_batches_dir = batch_manager_step_dir / buffered_step_name
- step_batches_dir.mkdir(parents=True, exist_ok=True)
-
- # Store each `_Batch` in a separate file
- step_dump["data"][buffered_step_name] = [
- str(
- save_batch(
- batches_dir=step_batches_dir,
- batch_dump=batch_dump,
- batch_list=self._steps[step_name].data[buffered_step_name],
- )
- )
- for batch_dump in step_dump["data"][buffered_step_name]
- ]
-
- # Remove `_Batch`es that were consumed from cache
- remove_files(step_dump["data"][buffered_step_name], step_batches_dir)
-
# Store the `_BatchManagerStep` info
batch_manager_step_file = str(
path.parent / f"batch_manager_steps/{step_name}/batch_manager_step.json"
@@ -965,29 +1121,138 @@ def remove_files(keep_files: List[str], dir: Path) -> None:
self.save(path=path, format="json", dump=dump)
@classmethod
- def load_from_cache(cls, path: "StrOrPath") -> "_BatchManager":
+ def load_from_cache(
+ cls, dag: "DAG", batch_manager_path: "StrOrPath", steps_data_path: "StrOrPath"
+ ) -> "_BatchManager":
"""Loads the `_BatchManager` from a cache file.
Args:
path: The path to the cache file.
"""
- _check_is_dir(path)
- content = read_json(path)
+ _check_is_dir(batch_manager_path)
+ content = read_json(batch_manager_path)
# Read each `_BatchManagerStep` from file
steps = {}
for step_name, step_file in content["steps"].items():
steps[step_name] = read_json(step_file)
+ # When reading back from JSON, `next_expected_seq_no` and `step_offset` is a
+ # list (because JSON files do not have tuples).
+ steps[step_name]["next_expected_seq_no"] = {
+ k: tuple(v) for k, v in steps[step_name]["next_expected_seq_no"].items()
+ }
+ steps[step_name]["step_offset"] = {
+ k: tuple(v) for k, v in steps[step_name]["step_offset"].items()
+ }
+
+ # TODO: where are we writing built batches now? xD
# Read each `_Batch` from file
steps[step_name]["built_batches"] = [
read_json(batch) for batch in steps[step_name]["built_batches"]
]
- for buffered_step_name, batch_files in steps[step_name]["data"].items():
- steps[step_name]["data"][buffered_step_name] = [
- read_json(batch_file) for batch_file in batch_files
- ]
+ # Read the batches from the `steps_data` directory to populate back the `_BatchManagerStep`
+ step_offset = steps[step_name]["step_offset"]
+ for successor_step_name, offset in step_offset.items():
+ batch_offset, batch_row_offset = offset
+ step: "_Step" = dag.get_step(successor_step_name)[STEP_ATTR_NAME]
+ successor_step_data_path = (
+ steps_data_path / f"{step.name}_{step.signature}"
+ )
+
+ # read batches from successor step from the step data directory taking into
+ # account offset
+ batches = []
+ for batch_file in successor_step_data_path.glob("*.json"):
+ if not batch_file.is_file() or batch_file.suffix != ".json":
+ continue
+
+ # If the batch number is lower than the batch offset then we should
+ # skip it as it has already been processed by the step
+ batch_no = int(batch_file.stem.split("batch_")[1])
+ if batch_no < batch_offset:
+ continue
+
+ # read the batch and skip the first N rows of the first batch
+ batch = read_json(batch_file)
+ if batch_no == batch_offset:
+ batch["data"][0] = batch["data"][0][batch_row_offset:]
+
+ batches.append(batch)
+
+ # sort batches by `seq_no` as it's a requirement for checking if ready to
+ # create next batch
+ batches.sort(key=lambda batch: batch["seq_no"])
+ steps[step_name]["data"][successor_step_name] = batches
content["steps"] = steps
return cls.from_dict(content)
+
+ def invalidate_cache_for(
+ self, step_name: str, dag: "DAG", steps_data_path: Path
+ ) -> None:
+ """Invalidates the cache for the given step and its predecessors.
+
+ Args:
+ step_name: the name of the step for which the cache will be invalidated.
+ dag: the `DAG` of the pipeline containing the steps.
+ steps_data_path: the path where the output batches of each `Step` were saved
+ for reuse in another pipeline execution.
+ """
+ invalidate_if_predecessor = []
+ for sorted_step in dag:
+ if (sorted_step == step_name) or any(
+ predecessor in invalidate_if_predecessor
+ for predecessor in dag.get_step_predecessors(sorted_step)
+ ):
+ self._reset_batch_manager_for_step(sorted_step, dag)
+ invalidate_if_predecessor.append(sorted_step)
+
+ self._load_predecessor_batches(step_name, dag, steps_data_path)
+
+ def _reset_batch_manager_for_step(self, step_name: str, dag: "DAG") -> None:
+ """Resets the batch manager state for a given step i.e. creates a new clean `_BatchManagerStep`
+ for the step and removes the step name from the lists of states of the `BatchManager`.
+
+ Args:
+ step_name: the name of step for which its batch manager state needs to be cleaned.
+ dag: the `DAG` of the pipeline containing the steps.
+ """
+ predecessors = list(dag.get_step_predecessors(step_name))
+ convergence_step = dag.is_convergence_step(step_name)
+ step = dag.get_step(step_name)[STEP_ATTR_NAME]
+ self._steps[step_name] = _BatchManagerStep.from_step(
+ step, predecessors=predecessors, convergence_step=convergence_step
+ )
+
+ self._last_batch_received[step_name] = None
+ self._last_batch_sent[step_name] = None
+ if step_name in self._last_batch_flag_sent_to:
+ self._last_batch_flag_sent_to.remove(step_name)
+
+ def _load_predecessor_batches(
+ self, step_name: str, dag: "DAG", steps_data_path: Path
+ ) -> None:
+ """Loads the cached batches of the predecessors of the step in its `_BatchManagerStep`.
+
+ Args:
+ step_name: the name of the step whose predecessors' batches will be loaded.
+ dag: the `DAG` of the pipeline containing the steps.
+ steps_data_path: the path where the output batches of each `Step` were saved
+ for reuse in another pipeline execution.
+ """
+ for predecessor in dag.get_step_predecessors(step_name):
+ step_predecessor = dag.get_step(predecessor)[STEP_ATTR_NAME]
+ predecessor_step_data_path = (
+ steps_data_path
+ / f"{step_predecessor.name}_{step_predecessor.signature}"
+ )
+ batch_files = list_files_in_dir(
+ predecessor_step_data_path, key=lambda x: int(x.stem.split("_")[-1])
+ )
+ for file in batch_files:
+ batch = _Batch.from_file(file)
+ if batch.last_batch:
+ self._steps[step_name].last_batch_received.append(batch.step_name)
+ self._steps[step_name].data[predecessor].append(batch)
diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py
index 74bfd0492b..c01cce303f 100644
--- a/src/distilabel/pipeline/local.py
+++ b/src/distilabel/pipeline/local.py
@@ -16,20 +16,31 @@
import signal
import sys
from multiprocessing.pool import Pool
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Union,
+ cast,
+)
import tblib
+from distilabel.constants import SIGINT_HANDLER_CALLED_ENV_NAME
from distilabel.distiset import create_distiset
-from distilabel.pipeline.base import (
- BasePipeline,
-)
+from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
+from distilabel.pipeline.base import BasePipeline, set_pipeline_running_env_variables
from distilabel.pipeline.ray import RayPipeline
from distilabel.pipeline.step_wrapper import _StepWrapper, _StepWrapperException
from distilabel.utils.logging import setup_logging, stop_logging
from distilabel.utils.ray import script_executed_in_ray_cluster
if TYPE_CHECKING:
+ import logging
from queue import Queue
from distilabel.distiset import Distiset
@@ -40,13 +51,27 @@
_SUBPROCESS_EXCEPTION: Union[Exception, None] = None
-def _init_worker(log_queue: "Queue[Any]") -> None:
+def _init_worker(
+ log_queue: "Queue[Any]", pipeline_name: str, pipeline_cache_id: str
+) -> None:
"""Init function for the child processes that will execute the `Step`s of the `Pipeline`.
Args:
log_queue: The queue to send the logs to the main process.
"""
- signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+ # Register a signal handler for SIGINT to avoid the default behavior of the process
+ # to terminate when the parent process receives a SIGINT signal. Instead, set an env
+ # variable when SIGINT is received. Child process can check the value of this env
+ # variable in sections of the code where they need to stop the execution if SIGINT
+ # was received (such as offline batch generation polling).
+ def signal_handler(sig: int, frame: Any) -> None:
+ import os
+
+ os.environ[SIGINT_HANDLER_CALLED_ENV_NAME] = "1"
+
+ signal.signal(signal.SIGINT, signal_handler)
+ set_pipeline_running_env_variables(pipeline_name, pipeline_cache_id)
setup_logging(log_queue)
@@ -122,11 +147,12 @@ def ray(
def run(
self,
- parameters: Optional[Dict[str, Dict[str, Any]]] = None,
+ parameters: Optional[Dict[Any, Dict[str, Any]]] = None,
use_cache: bool = True,
storage_parameters: Optional[Dict[str, Any]] = None,
use_fs_to_pass_data: bool = False,
dataset: Optional["InputDataset"] = None,
+ logging_handlers: Optional[List["logging.Handler"]] = None,
) -> "Distiset":
"""Runs the pipeline.
@@ -149,6 +175,9 @@ def run(
dataset: If given, it will be used to create a `GeneratorStep` and put it as the
root step. Convenient method when you have already processed the dataset in
your script and just want to pass it already processed. Defaults to `None`.
+ logging_handlers: A list of logging handlers that will be used to log the
+ output of the pipeline. This argument can be useful so the logging messages
+ can be extracted and used in a different context. Defaults to `None`.
Returns:
The `Distiset` created by the pipeline.
@@ -169,11 +198,12 @@ def run(
self._log_queue = cast("Queue[Any]", mp.Queue())
if distiset := super().run(
- parameters,
- use_cache,
- storage_parameters,
- use_fs_to_pass_data,
+ parameters=parameters,
+ use_cache=use_cache,
+ storage_parameters=storage_parameters,
+ use_fs_to_pass_data=use_fs_to_pass_data,
dataset=dataset,
+ logging_handlers=logging_handlers,
):
return distiset
@@ -183,7 +213,11 @@ def run(
_NoDaemonPool(
num_processes,
initializer=_init_worker,
- initargs=(self._log_queue,),
+ initargs=(
+ self._log_queue,
+ self.name,
+ self.signature,
+ ),
) as pool,
):
self._manager = manager
@@ -288,6 +322,21 @@ def _error_callback(self, e: BaseException) -> None:
self._logger.error(f"Subprocess traceback:\n\n{e.formatted_traceback}")
return
+ # Handle tasks using an `LLM` using offline batch generation
+ if isinstance(
+ e.subprocess_exception, DistilabelOfflineBatchGenerationNotFinishedException
+ ):
+ self._logger.info(
+ f"⏹️ '{e.step.name}' task stopped pipeline execution: LLM offline batch"
+ " generation in progress. Rerun pipeline with cache to check results and"
+ " continue execution."
+ )
+ self._set_step_for_recovering_offline_batch_generation(e.step, e.data) # type: ignore
+ with self._stop_called_lock:
+ if not self._stop_called:
+ self._stop(acquire_lock=False)
+ return
+
# Global step with successors failed
self._logger.error(f"An error occurred in global step '{step_name}'")
self._logger.error(f"Subprocess traceback:\n\n{e.formatted_traceback}")
@@ -324,38 +373,45 @@ def _set_steps_not_loaded_exception(self) -> None:
)
self._exception.__cause__ = _SUBPROCESS_EXCEPTION
- def _stop(self) -> None:
+ def _stop(self, acquire_lock: bool = True) -> None:
"""Stops the pipeline execution. It will first send `None` to the input queues
of all the steps and then wait until the output queue is empty i.e. all the steps
finished processing the batches that were sent before the stop flag. Then it will
- send `None` to the output queue to notify the pipeline to stop."""
+ send `None` to the output queue to notify the pipeline to stop.
+
+ Args:
+ acquire_lock: Whether to acquire the lock to access the `_stop_called` attribute.
+ """
- with self._stop_called_lock:
- if self._stop_called:
- self._stop_calls += 1
- if self._stop_calls == 1:
- self._logger.warning(
- "🛑 Press again to force the pipeline to stop."
- )
- elif self._stop_calls > 1:
- self._logger.warning("🛑 Forcing pipeline interruption.")
+ if acquire_lock:
+ self._stop_called_lock.acquire()
- if self._pool:
- self._pool.terminate()
- self._pool.join()
- self._pool = None
+ if self._stop_called:
+ self._stop_calls += 1
+ if self._stop_calls == 1:
+ self._logger.warning("🛑 Press again to force the pipeline to stop.")
+ elif self._stop_calls > 1:
+ self._logger.warning("🛑 Forcing pipeline interruption.")
- if self._manager:
- self._manager.shutdown()
- self._manager.join()
- self._manager = None
+ if self._pool:
+ self._pool.terminate()
+ self._pool.join()
+ self._pool = None
- stop_logging()
+ if self._manager:
+ self._manager.shutdown()
+ self._manager.join()
+ self._manager = None
- sys.exit(1)
+ stop_logging()
- return
- self._stop_called = True
+ sys.exit(1)
+
+ return
+ self._stop_called = True
+
+ if acquire_lock:
+ self._stop_called_lock.release()
self._logger.debug(
f"Steps loaded before calling `stop`: {self._steps_load_status}"
@@ -364,5 +420,4 @@ def _stop(self) -> None:
"🛑 Stopping pipeline. Waiting for steps to finish processing batches..."
)
- self._stop_load_queue_loop()
self._stop_output_queue_loop()
diff --git a/src/distilabel/pipeline/ray.py b/src/distilabel/pipeline/ray.py
index aad72a61cb..70bf205ab3 100644
--- a/src/distilabel/pipeline/ray.py
+++ b/src/distilabel/pipeline/ray.py
@@ -15,15 +15,17 @@
import sys
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
-from distilabel.constants import INPUT_QUEUE_ATTR_NAME
+from distilabel.constants import INPUT_QUEUE_ATTR_NAME, STEP_ATTR_NAME
from distilabel.distiset import create_distiset
+from distilabel.errors import DistilabelUserError
from distilabel.llms.vllm import vLLM
-from distilabel.pipeline.base import BasePipeline
+from distilabel.pipeline.base import BasePipeline, set_pipeline_running_env_variables
from distilabel.pipeline.step_wrapper import _StepWrapper
from distilabel.utils.logging import setup_logging, stop_logging
from distilabel.utils.serialization import TYPE_INFO_KEY
if TYPE_CHECKING:
+ import logging
from os import PathLike
from queue import Queue
@@ -81,6 +83,7 @@ def run(
storage_parameters: Optional[Dict[str, Any]] = None,
use_fs_to_pass_data: bool = False,
dataset: Optional["InputDataset"] = None,
+ logging_handlers: Optional[List["logging.Handler"]] = None,
) -> "Distiset":
"""Runs the pipeline in the Ray cluster.
@@ -103,6 +106,9 @@ def run(
dataset: If given, it will be used to create a `GeneratorStep` and put it as the
root step. Convenient method when you have already processed the dataset in
your script and just want to pass it already processed. Defaults to `None`.
+ logging_handlers: A list of logging handlers that will be used to log the
+ output of the pipeline. This argument can be useful so the logging messages
+ can be extracted and used in a different context. Defaults to `None`.
Returns:
The `Distiset` created by the pipeline.
@@ -110,6 +116,8 @@ def run(
Raises:
RuntimeError: If the pipeline fails to load all the steps.
"""
+ self._check_no_llms_using_offline_batch_generation()
+
self._init_ray()
self._log_queue = self.QueueClass(
@@ -117,11 +125,12 @@ def run(
)
if distiset := super().run(
- parameters,
- use_cache,
- storage_parameters,
- use_fs_to_pass_data,
+ parameters=parameters,
+ use_cache=use_cache,
+ storage_parameters=storage_parameters,
+ use_fs_to_pass_data=use_fs_to_pass_data,
dataset=dataset,
+ logging_handlers=logging_handlers,
):
return distiset
@@ -160,6 +169,21 @@ def run(
return distiset
+ def _check_no_llms_using_offline_batch_generation(self) -> None:
+ """Checks if there are any `LLM` steps using the `offline_batch_generate` method
+ and raises an exception if so. This method is not supported in the Ray pipeline."""
+ for step_name in self.dag:
+ step: "_Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME]
+ if not hasattr(step, "llm"):
+ continue
+ if step.llm.use_offline_batch_generation: # type: ignore
+ raise DistilabelUserError(
+ f"Step '{step_name}' uses an `LLM` with offline batch generation because"
+ "`use_offline_batch_generation=True`. `LLM`s using this method are not"
+ " supported in the Ray pipeline.",
+ page="sections/how_to_guides/advanced/offline-batch-generation",
+ )
+
def _init_ray(self) -> None:
"""Inits or connects to a Ray cluster."""
try:
@@ -231,13 +255,22 @@ def _run_step(self, step: "_Step", input_queue: "Queue[Any]", replica: int) -> N
@ray.remote
class _StepWrapperRay:
def __init__(
- self, step_wrapper: _StepWrapper, log_queue: "Queue[Any]"
+ self,
+ step_wrapper: _StepWrapper,
+ log_queue: "Queue[Any]",
+ pipeline_name: str,
+ pipeline_cache_id: str,
) -> None:
self._step_wrapper = step_wrapper
self._log_queue = log_queue
+ self._pipeline_name = pipeline_name
+ self._pipeline_cache_id = pipeline_cache_id
def run(self) -> str:
setup_logging(log_queue=self._log_queue)
+ set_pipeline_running_env_variables(
+ self._pipeline_name, self._pipeline_cache_id
+ )
return self._step_wrapper.run()
resources: Dict[str, Any] = {
@@ -276,6 +309,8 @@ def run(self) -> str:
ray_pipeline=True,
),
log_queue=self._log_queue,
+ pipeline_name=self.name,
+ pipeline_cache_id=self.signature,
)
self._logger.debug(
@@ -399,7 +434,6 @@ def _stop(self) -> None:
"🛑 Stopping pipeline. Waiting for steps to finish processing batches..."
)
- self._stop_load_queue_loop()
self._stop_output_queue_loop()
def dump(self, **kwargs: Any) -> Dict[str, Any]:
diff --git a/src/distilabel/pipeline/routing_batch_function.py b/src/distilabel/pipeline/routing_batch_function.py
index ee2ca0f8c8..e29a520405 100644
--- a/src/distilabel/pipeline/routing_batch_function.py
+++ b/src/distilabel/pipeline/routing_batch_function.py
@@ -19,6 +19,7 @@
from pydantic import BaseModel, PrivateAttr
from typing_extensions import Self
+from distilabel.errors import DistilabelUserError
from distilabel.utils.serialization import (
TYPE_INFO_KEY,
_get_module_attr,
@@ -134,19 +135,21 @@ def __rshift__(
routing batch function.
"""
if not isinstance(other, list):
- raise ValueError(
+ raise DistilabelUserError(
f"Can only set a `routing_batch_function` for a list of steps. Got: {other}."
" Please, review the right-hand side of the `routing_batch_function >> other`"
" expression. It should be"
- " `upstream_step >> routing_batch_function >> [downstream_step_1, dowstream_step_2, ...]`."
+ " `upstream_step >> routing_batch_function >> [downstream_step_1, dowstream_step_2, ...]`.",
+ page="sections/how_to_guides/basic/pipeline/?h=routing#routing-batches-to-specific-downstream-steps",
)
if not self._step:
- raise ValueError(
+ raise DistilabelUserError(
"Routing batch function doesn't have an upstream step. Cannot connect downstream"
" steps before connecting the upstream step. Connect this routing batch"
" function to an upstream step using the `>>` operator. For example:"
- " `upstream_step >> routing_batch_function >> [downstream_step_1, downstream_step_2, ...]`."
+ " `upstream_step >> routing_batch_function >> [downstream_step_1, downstream_step_2, ...]`.",
+ page="sections/how_to_guides/basic/pipeline/?h=routing#routing-batches-to-specific-downstream-steps",
)
for step in other:
diff --git a/src/distilabel/pipeline/step_wrapper.py b/src/distilabel/pipeline/step_wrapper.py
index ad4a668bf2..844648f202 100644
--- a/src/distilabel/pipeline/step_wrapper.py
+++ b/src/distilabel/pipeline/step_wrapper.py
@@ -17,6 +17,8 @@
from typing import Any, Dict, List, Optional, Union, cast
from distilabel.constants import LAST_BATCH_SENT_FLAG
+from distilabel.errors import DISTILABEL_DOCS_URL
+from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.typing import StepLoadStatus
@@ -74,7 +76,7 @@ def _init_cuda_device_placement_mixin(attr: CudaDevicePlacementMixin) -> None:
attr.disable_cuda_device_placement = True
else:
desired_num_gpus = self.step.resources.gpus or 1
- attr._llm_identifier = self.step.name
+ attr._llm_identifier = f"{self.step.name}-replica-{self.replica}"
attr._desired_num_gpus = desired_num_gpus
for field_name in self.step.model_fields_set:
@@ -132,14 +134,23 @@ def run(self) -> str:
def _notify_load(self) -> None:
"""Notifies that the step has finished executing its `load` function successfully."""
+ self.step._logger.debug(
+ f"Notifying load of step '{self.step.name}' (replica ID {self.replica})..."
+ )
self.load_queue.put({"name": self.step.name, "status": "loaded"}) # type: ignore
def _notify_unload(self) -> None:
"""Notifies that the step has been unloaded."""
+ self.step._logger.debug(
+ f"Notifying unload of step '{self.step.name}' (replica ID {self.replica})..."
+ )
self.load_queue.put({"name": self.step.name, "status": "unloaded"}) # type: ignore
def _notify_load_failed(self) -> None:
"""Notifies that the step failed to load."""
+ self.step._logger.debug(
+ f"Notifying load failed of step '{self.step.name}' (replica ID {self.replica})..."
+ )
self.load_queue.put({"name": self.step.name, "status": "load_failed"}) # type: ignore
def _generator_step_process_loop(self) -> None:
@@ -155,6 +166,7 @@ def _generator_step_process_loop(self) -> None:
`process` method.
"""
step = cast("GeneratorStep", self.step)
+
try:
if (batch := self.input_queue.get()) is None:
self.step._logger.info(
@@ -230,7 +242,16 @@ def _non_generator_process_loop(self) -> None:
result = next(step.process_applying_mappings(batch.data[0]))
except Exception as e:
if self.step.is_global:
- raise _StepWrapperException(str(e), self.step, 2, e) from e
+ self.step.unload()
+ self._notify_unload()
+ data = (
+ batch.data
+ if isinstance(
+ e, DistilabelOfflineBatchGenerationNotFinishedException
+ )
+ else None
+ )
+ raise _StepWrapperException(str(e), self.step, 2, e, data) from e
# Impute step outputs columns with `None`
result = self._impute_step_outputs(batch)
@@ -257,13 +278,7 @@ def _impute_step_outputs(self, batch: "_Batch") -> List[Dict[str, Any]]:
Args:
batch: The batch to impute.
"""
- result = []
- for row in batch.data[0]:
- data = row.copy()
- for output in self.step.outputs:
- data[output] = None
- result.append(data)
- return result
+ return self.step.impute_step_outputs(batch.data[0])
def _send_batch(self, batch: _Batch) -> None:
"""Sends a batch to the `output_queue`."""
@@ -284,7 +299,8 @@ class _StepWrapperException(Exception):
message: The error message.
step: The `Step` that raised the error.
code: The error code.
- subprocess_exception: The exception raised by the subprocess. Defaults to `None`.
+ subprocess_exception: The exception raised by the subprocess.
+ data: The data that caused the error. Defaults to `None`.
"""
def __init__(
@@ -292,15 +308,21 @@ def __init__(
message: str,
step: "_Step",
code: int,
- subprocess_exception: Optional[Exception] = None,
+ subprocess_exception: Exception,
+ data: Optional[List[List[Dict[str, Any]]]] = None,
) -> None:
- self.message = message
+ self.message = f"{message}\n\nFor further information visit '{DISTILABEL_DOCS_URL}api/pipeline/step_wrapper'"
self.step = step
self.code = code
self.subprocess_exception = subprocess_exception
self.formatted_traceback = "".join(
- traceback.format_exception(subprocess_exception)
+ traceback.format_exception(
+ type(subprocess_exception),
+ subprocess_exception,
+ subprocess_exception.__traceback__,
+ )
)
+ self.data = data
@classmethod
def create_load_error(
@@ -319,7 +341,7 @@ def create_load_error(
Returns:
The `_StepWrapperException` instance.
"""
- return cls(message, step, 1, subprocess_exception)
+ return cls(message, step, 1, subprocess_exception, None)
@property
def is_load_error(self) -> bool:
diff --git a/src/distilabel/pipeline/write_buffer.py b/src/distilabel/pipeline/write_buffer.py
index a71ffdd9b2..3fdb037e14 100644
--- a/src/distilabel/pipeline/write_buffer.py
+++ b/src/distilabel/pipeline/write_buffer.py
@@ -15,7 +15,7 @@
import logging
from os import PathLike
from pathlib import Path
-from typing import Any, Dict, List, Set
+from typing import Any, Dict, List, Optional, Set
import pyarrow as pa
import pyarrow.parquet as pq
@@ -33,12 +33,21 @@ class _WriteBuffer:
is full, the content is written to a parquet file.
"""
- def __init__(self, path: "PathLike", leaf_steps: Set[str]) -> None:
+ def __init__(
+ self,
+ path: "PathLike",
+ leaf_steps: Set[str],
+ steps_cached: Optional[Dict[str, bool]] = None,
+ ) -> None:
"""
Args:
path: Folder where the files will be written, the idea
is for this path to be in the cache folder under /data.
leaf_steps: Leaf steps from either the DAG of the Pipeline.
+ steps_cached: Dictionary with the name of a step and the variable
+ use_cache. We will use this to determine whether we have to read
+ a previous parquet table to concatenate before saving the cached
+ datasets.
Raises:
ValueError: If the path is not a directory.
@@ -61,6 +70,7 @@ def __init__(self, path: "PathLike", leaf_steps: Set[str]) -> None:
}
self._buffer_last_schema = {}
self._buffers_last_file: Dict[str, int] = {step: 1 for step in leaf_steps}
+ self._steps_cached = steps_cached or {}
self._logger = logging.getLogger("distilabel.write_buffer")
def _get_filename(self, step_name: str) -> Path:
@@ -130,14 +140,28 @@ def _write(self, step_name: str) -> None:
self._buffer_last_schema[step_name] = table.schema
else:
if not last_schema.equals(table.schema):
- new_schema = pa.unify_schemas([last_schema, table.schema])
- self._buffer_last_schema[step_name] = new_schema
- table = table.cast(new_schema)
+ if set(last_schema.names) == set(table.schema.names):
+ table = table.select(last_schema.names)
+ else:
+ new_schema = pa.unify_schemas([last_schema, table.schema])
+ self._buffer_last_schema[step_name] = new_schema
+ table = table.cast(new_schema)
next_file_number = self._buffers_last_file[step_name]
self._buffers_last_file[step_name] = next_file_number + 1
parquet_file = step_parquet_dir / f"{str(next_file_number).zfill(5)}.parquet"
+ if parquet_file.exists():
+ # If the file already exists, due to some error in a pipeline that was cached
+ prev_table = pq.read_table(parquet_file)
+ # If some columns differ, it means some of the step changed, we won't load the previous table
+ # NOTE: If any step has use_cache=False, we cannot assume the previous parquet file is
+ # valid, so we will overwrite the previous parquet file. Is this the best option?
+ use_cache = False not in self._steps_cached.values()
+
+ if prev_table.column_names == table.column_names and use_cache:
+ table = pa.concat_tables([prev_table, table])
+
pq.write_table(table, parquet_file)
self._logger.debug(f"Written to file '{parquet_file}'")
diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py
index bd8fde2251..cc1be59f92 100644
--- a/src/distilabel/steps/__init__.py
+++ b/src/distilabel/steps/__init__.py
@@ -21,6 +21,10 @@
StepInput,
StepResources,
)
+from distilabel.steps.clustering.dbscan import DBSCAN
+from distilabel.steps.clustering.text_clustering import TextClustering
+from distilabel.steps.clustering.umap import UMAP
+from distilabel.steps.columns.combine import CombineOutputs
from distilabel.steps.columns.expand import ExpandColumns
from distilabel.steps.columns.group import CombineColumns, GroupColumns
from distilabel.steps.columns.keep import KeepColumns
@@ -29,6 +33,8 @@
from distilabel.steps.deita import DeitaFiltering
from distilabel.steps.embeddings.embedding_generation import EmbeddingGeneration
from distilabel.steps.embeddings.nearest_neighbour import FaissNearestNeighbour
+from distilabel.steps.filtering.embedding import EmbeddingDedup
+from distilabel.steps.filtering.minhash import MinHashDedup
from distilabel.steps.formatting.conversation import ConversationTemplate
from distilabel.steps.formatting.dpo import (
FormatChatGenerationDPO,
@@ -39,6 +45,7 @@
FormatTextGenerationSFT,
)
from distilabel.steps.generators.data import LoadDataFromDicts
+from distilabel.steps.generators.data_sampler import DataSampler
from distilabel.steps.generators.huggingface import (
LoadDataFromDisk,
LoadDataFromFileSystem,
@@ -47,37 +54,46 @@
from distilabel.steps.generators.utils import make_generator_step
from distilabel.steps.globals.huggingface import PushToHub
from distilabel.steps.reward_model import RewardModelScore
+from distilabel.steps.truncate import TruncateTextColumn
from distilabel.steps.typing import GeneratorStepOutput, StepOutput
__all__ = [
"PreferenceToArgilla",
"TextGenerationToArgilla",
+ "GeneratorStep",
+ "GlobalStep",
+ "Step",
+ "StepInput",
"StepResources",
+ "CombineOutputs",
+ "ExpandColumns",
+ "CombineColumns",
"GroupColumns",
+ "KeepColumns",
"MergeColumns",
- "CombineColumns",
- "ConversationTemplate",
+ "DBSCAN",
+ "UMAP",
+ "TextClustering",
+ "step",
"DeitaFiltering",
"EmbeddingGeneration",
"FaissNearestNeighbour",
- "ExpandColumns",
+ "ConversationTemplate",
"FormatChatGenerationDPO",
- "FormatChatGenerationSFT",
"FormatTextGenerationDPO",
+ "FormatChatGenerationSFT",
"FormatTextGenerationSFT",
- "GeneratorStep",
- "GlobalStep",
- "KeepColumns",
"LoadDataFromDicts",
+ "DataSampler",
"LoadDataFromDisk",
"LoadDataFromFileSystem",
"LoadDataFromHub",
+ "EmbeddingDedup",
+ "MinHashDedup",
"make_generator_step",
"PushToHub",
- "Step",
- "StepInput",
"RewardModelScore",
+ "TruncateTextColumn",
"GeneratorStepOutput",
"StepOutput",
- "step",
]
diff --git a/src/distilabel/steps/argilla/base.py b/src/distilabel/steps/argilla/base.py
index 1de89d38cf..ea491e07a5 100644
--- a/src/distilabel/steps/argilla/base.py
+++ b/src/distilabel/steps/argilla/base.py
@@ -15,7 +15,7 @@
import importlib.util
import os
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, Optional
from pydantic import Field, PrivateAttr, SecretStr
@@ -24,13 +24,14 @@
except ImportError:
pass
+from distilabel.errors import DistilabelUserError
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
from argilla import Argilla, Dataset
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
_ARGILLA_API_URL_ENV_VAR_NAME = "ARGILLA_API_URL"
@@ -110,7 +111,10 @@ def _client_init(self) -> None:
else {},
)
except Exception as e:
- raise ValueError(f"Failed to initialize the Argilla API: {e}") from e
+ raise DistilabelUserError(
+ f"Failed to initialize the Argilla API: {e}",
+ page="sections/how_to_guides/advanced/argilla/",
+ ) from e
@property
def _dataset_exists_in_workspace(self) -> bool:
@@ -128,7 +132,7 @@ def _dataset_exists_in_workspace(self) -> bool:
)
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""The outputs of the step is an empty list, since the steps subclassing from this one, will
always be leaf nodes and won't propagate the inputs neither generate any outputs.
"""
@@ -141,17 +145,18 @@ def load(self) -> None:
super().load()
if self.api_url is None or self.api_key is None:
- raise ValueError(
+ raise DistilabelUserError(
"`Argilla` step requires the `api_url` and `api_key` to be provided. Please,"
" provide those at step instantiation, via environment variables `ARGILLA_API_URL`"
- " and `ARGILLA_API_KEY`, or as `Step` runtime parameters via `pipeline.run(parameters={...})`."
+ " and `ARGILLA_API_KEY`, or as `Step` runtime parameters via `pipeline.run(parameters={...})`.",
+ page="sections/how_to_guides/advanced/argilla/",
)
self._client_init()
@property
@abstractmethod
- def inputs(self) -> List[str]: ...
+ def inputs(self) -> "StepColumns": ...
@abstractmethod
def process(self, *inputs: StepInput) -> "StepOutput": ...
diff --git a/src/distilabel/steps/argilla/preference.py b/src/distilabel/steps/argilla/preference.py
index 572c2e6746..210cca208f 100644
--- a/src/distilabel/steps/argilla/preference.py
+++ b/src/distilabel/steps/argilla/preference.py
@@ -23,6 +23,7 @@
except ImportError:
pass
+from distilabel.errors import DistilabelUserError
from distilabel.steps.argilla.base import ArgillaBase
from distilabel.steps.base import StepInput
@@ -70,7 +71,6 @@ class PreferenceToArgilla(ArgillaBase):
generated rationales won't be pushed to Argilla.
Examples:
-
Push a preference dataset to an Argilla instance:
```python
@@ -166,11 +166,12 @@ def load(self) -> None:
]
and field.required
):
- raise ValueError(
+ raise DistilabelUserError(
f"The dataset '{self.dataset_name}' in the workspace '{self.dataset_workspace}'"
f" already exists, but contains at least a required field that is"
f" neither `{self._id}`, `{self._instruction}`, nor `{self._generations}`"
- f" (one per generation starting from 0 up to {self.num_generations - 1})."
+ f" (one per generation starting from 0 up to {self.num_generations - 1}).",
+ page="components-gallery/steps/preferencetoargilla/",
)
self._dataset = _dataset
diff --git a/src/distilabel/steps/argilla/text_generation.py b/src/distilabel/steps/argilla/text_generation.py
index d3df0767c8..ad5323b0bc 100644
--- a/src/distilabel/steps/argilla/text_generation.py
+++ b/src/distilabel/steps/argilla/text_generation.py
@@ -23,6 +23,7 @@
except ImportError:
pass
+from distilabel.errors import DistilabelUserError
from distilabel.steps.argilla.base import ArgillaBase
from distilabel.steps.base import StepInput
@@ -60,7 +61,6 @@ class TextGenerationToArgilla(ArgillaBase):
- generation (`str` or `List[str]`): The completions that were generated based on the input instruction.
Examples:
-
Push a text generation dataset to an Argilla instance:
```python
@@ -117,11 +117,12 @@ def load(self) -> None:
field.name not in [self._id, self._instruction, self._generation]
and field.required
):
- raise ValueError(
+ raise DistilabelUserError(
f"The dataset '{self.dataset_name}' in the workspace '{self.dataset_workspace}'"
f" already exists, but contains at least a required field that is"
f" neither `{self._id}`, `{self._instruction}`, nor `{self._generation}`,"
- " so it cannot be reused for this dataset."
+ " so it cannot be reused for this dataset.",
+ page="components-gallery/steps/textgenerationtoargilla/",
)
self._dataset = _dataset
diff --git a/src/distilabel/steps/base.py b/src/distilabel/steps/base.py
index 940b05d812..b98c0e8275 100644
--- a/src/distilabel/steps/base.py
+++ b/src/distilabel/steps/base.py
@@ -17,17 +17,30 @@
import re
from abc import ABC, abstractmethod
from functools import cached_property
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, overload
+from pathlib import Path
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ overload,
+)
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr
from typing_extensions import Annotated, Self
+from distilabel.errors import DistilabelTypeError, DistilabelUserError
from distilabel.mixins.requirements import RequirementsMixin
from distilabel.mixins.runtime_parameters import (
RuntimeParameter,
RuntimeParametersMixin,
)
-from distilabel.utils.serialization import _Serializable
+from distilabel.mixins.signature import SignatureMixin
+from distilabel.utils.serialization import _Serializable, write_json
from distilabel.utils.typing_ import is_parameter_annotated_with
if TYPE_CHECKING:
@@ -40,7 +53,7 @@
DownstreamConnectableSteps,
UpstreamConnectableSteps,
)
- from distilabel.steps.typing import GeneratorStepOutput, StepOutput
+ from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput
DEFAULT_INPUT_BATCH_SIZE = 50
@@ -121,7 +134,14 @@ class StepResources(RuntimeParametersMixin, BaseModel):
)
-class _Step(RuntimeParametersMixin, RequirementsMixin, BaseModel, _Serializable, ABC):
+class _Step(
+ RuntimeParametersMixin,
+ RequirementsMixin,
+ SignatureMixin,
+ BaseModel,
+ _Serializable,
+ ABC,
+):
"""Base class for the steps that can be included in a `Pipeline`.
A `Step` is a class defining some processing logic. The input and outputs for this
@@ -181,7 +201,9 @@ def process(self, inputs: *StepInput) -> StepOutput:
pipeline: Any = Field(default=None, exclude=True, repr=False)
input_mappings: Dict[str, str] = {}
output_mappings: Dict[str, str] = {}
+ use_cache: bool = True
+ _pipeline_artifacts_path: Path = PrivateAttr(None)
_built_from_decorator: bool = PrivateAttr(default=False)
_logger: "Logger" = PrivateAttr(None)
@@ -369,8 +391,10 @@ def is_normal(self) -> bool:
return not self.is_generator and not self.is_global
@property
- def inputs(self) -> List[str]:
- """List of strings with the names of the columns that the step needs as input.
+ def inputs(self) -> "StepColumns":
+ """List of strings with the names of the mandatory columns that the step needs as
+ input or dictionary in which the keys are the input columns of the step and the
+ values are booleans indicating whether the column is optional or not.
Returns:
List of strings with the names of the columns that the step needs as input.
@@ -378,9 +402,10 @@ def inputs(self) -> List[str]:
return []
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""List of strings with the names of the columns that the step will produce as
- output.
+ output or dictionary in which the keys are the output columns of the step and the
+ values are booleans indicating whether the column is optional or not.
Returns:
List of strings with the names of the columns that the step will produce as
@@ -424,9 +449,10 @@ def get_process_step_input(self) -> Union[inspect.Parameter, None]:
for parameter in self.process_parameters:
if is_parameter_annotated_with(parameter, _STEP_INPUT_ANNOTATION):
if step_input_parameter is not None:
- raise TypeError(
+ raise DistilabelTypeError(
f"Step '{self.name}' should have only one parameter with type"
- " hint `StepInput`."
+ " hint `StepInput`.",
+ page="sections/how_to_guides/basic/step/#defining-custom-steps",
)
step_input_parameter = parameter
return step_input_parameter
@@ -443,10 +469,11 @@ def verify_inputs_mappings(self) -> None:
for input in self.input_mappings:
if input not in self.inputs:
- raise ValueError(
+ raise DistilabelUserError(
f"The input column '{input}' doesn't exist in the inputs of the"
f" step '{self.name}'. Inputs of the step are: {self.inputs}."
- " Please, review the `inputs_mappings` argument of the step."
+ " Please, review the `inputs_mappings` argument of the step.",
+ page="sections/how_to_guides/basic/step/#arguments",
)
def verify_outputs_mappings(self) -> None:
@@ -461,29 +488,122 @@ def verify_outputs_mappings(self) -> None:
for output in self.output_mappings:
if output not in self.outputs:
- raise ValueError(
+ raise DistilabelUserError(
f"The output column '{output}' doesn't exist in the outputs of the"
f" step '{self.name}'. Outputs of the step are: {self.outputs}."
- " Please, review the `outputs_mappings` argument of the step."
+ " Please, review the `outputs_mappings` argument of the step.",
+ page="sections/how_to_guides/basic/step/#arguments",
)
- def get_inputs(self) -> List[str]:
+ def get_inputs(self) -> Dict[str, bool]:
"""Gets the inputs of the step after the `input_mappings`. This method is meant
to be used to run validations on the inputs of the step.
Returns:
- The inputs of the step after the `input_mappings`.
+ The inputs of the step after the `input_mappings` and if they are required or
+ not.
"""
- return [self.input_mappings.get(input, input) for input in self.inputs]
+ if isinstance(self.inputs, list):
+ return {
+ self.input_mappings.get(input, input): True for input in self.inputs
+ }
+
+ return {
+ self.input_mappings.get(input, input): required
+ for input, required in self.inputs.items()
+ }
- def get_outputs(self) -> List[str]:
+ def get_outputs(self) -> Dict[str, bool]:
"""Gets the outputs of the step after the `outputs_mappings`. This method is
meant to be used to run validations on the outputs of the step.
Returns:
- The outputs of the step after the `outputs_mappings`.
+ The outputs of the step after the `outputs_mappings` and if they are required
+ or not.
+ """
+ if isinstance(self.outputs, list):
+ return {
+ self.output_mappings.get(output, output): True
+ for output in self.outputs
+ }
+
+ return {
+ self.output_mappings.get(output, output): required
+ for output, required in self.outputs.items()
+ }
+
+ def set_pipeline_artifacts_path(self, path: Path) -> None:
+ """Sets the `_pipeline_artifacts_path` attribute. This method is meant to be used
+ by the `Pipeline` once the cache location is known.
+
+ Args:
+ path: the path where the artifacts generated by the pipeline steps should be
+ saved.
+ """
+ self._pipeline_artifacts_path = path
+
+ @property
+ def artifacts_directory(self) -> Union[Path, None]:
+ """Gets the path of the directory where the step should save its generated artifacts.
+
+ Returns:
+ The path of the directory where the step should save the generated artifacts,
+ or `None` if `_pipeline_artifacts_path` is not set.
+ """
+ if self._pipeline_artifacts_path is None:
+ return None
+ return self._pipeline_artifacts_path / self.name # type: ignore
+
+ def save_artifact(
+ self,
+ name: str,
+ write_function: Callable[[Path], None],
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """Saves an artifact generated by the `Step`.
+
+ Args:
+ name: the name of the artifact.
+ write_function: a function that will receive the path where the artifact should
+ be saved.
+ metadata: the artifact metadata. Defaults to `None`.
"""
- return [self.output_mappings.get(output, output) for output in self.outputs]
+ if self.artifacts_directory is None:
+ self._logger.warning(
+ f"Cannot save artifact with '{name}' as `_pipeline_artifacts_path` is not"
+ " set. This is normal if the `Step` is being executed as a standalone component."
+ )
+ return
+
+ artifact_directory_path = self.artifacts_directory / name
+ artifact_directory_path.mkdir(parents=True, exist_ok=True)
+
+ self._logger.info(f"🏺 Storing '{name}' generated artifact...")
+
+ self._logger.debug(
+ f"Calling `write_function` to write artifact in '{artifact_directory_path}'..."
+ )
+ write_function(artifact_directory_path)
+
+ metadata_path = artifact_directory_path / "metadata.json"
+ self._logger.debug(
+ f"Calling `write_json` to write artifact metadata in '{metadata_path}'..."
+ )
+ write_json(filename=metadata_path, data=metadata or {})
+
+ def impute_step_outputs(
+ self, step_output: List[Dict[str, Any]]
+ ) -> List[Dict[str, Any]]:
+ """
+ Imputes the output columns of the step that are not present in the step output.
+ """
+ result = []
+ for row in step_output:
+ data = row.copy()
+ for output in self.get_outputs().keys():
+ data[output] = None
+ result.append(data)
+ return result
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
dump = super()._model_dump(obj, **kwargs)
@@ -531,7 +651,11 @@ def process_applying_mappings(self, *args: List[Dict[str, Any]]) -> "StepOutput"
The output rows.
"""
- inputs = self._apply_input_mappings(args) if self.input_mappings else args
+ inputs, overriden_inputs = (
+ self._apply_input_mappings(args)
+ if self.input_mappings
+ else (args, [{} for _ in range(len(args[0]))])
+ )
# If the `Step` was built using the `@step` decorator, then we need to pass
# the runtime parameters as kwargs, so they can be used within the processing
@@ -543,48 +667,86 @@ def process_applying_mappings(self, *args: List[Dict[str, Any]]) -> "StepOutput"
)
for output_rows in generator:
- yield [
- {
- # Apply output mapping and revert input mapping
- self.output_mappings.get(k, None)
- or self.input_mappings.get(k, None)
- or k: v
- for k, v in row.items()
- }
- for row in output_rows
- ]
-
- def _revert_input_mappings(self, input: Dict[str, Any]) -> Dict[str, Any]:
- """Reverts the `input_mappings` of the step to the input row.
-
- Args:
- input: The input row.
-
- Returns:
- The input row with the `input_mappings` reverted.
- """
- return {self.input_mappings.get(k, k): v for k, v in input.items()}
+ restored = []
+ for i, row in enumerate(output_rows):
+ # Correct the index here because we don't know the num_generations from the llm
+ # ahead of time. For example, if we have `len(overriden_inputs)==5` and `len(row)==10`,
+ # from `num_generations==2` and `group_generations=False` in the LLM:
+ # The loop will use indices 0, 1, 2, 3, 4, 0, 1, 2, 3, 4
+ ntimes_i = i % len(overriden_inputs)
+ restored.append(
+ self._apply_mappings_and_restore_overriden(
+ row, overriden_inputs[ntimes_i]
+ )
+ )
+ yield restored
def _apply_input_mappings(
self, inputs: Tuple[List[Dict[str, Any]], ...]
- ) -> List[List[Dict[str, Any]]]:
+ ) -> Tuple[Tuple[List[Dict[str, Any]], ...], List[Dict[str, Any]]]:
"""Applies the `input_mappings` to the input rows.
Args:
inputs: The input rows.
Returns:
- The input rows with the `input_mappings` applied.
+ The input rows with the `input_mappings` applied and the overriden values
+ that were replaced by the `input_mappings`.
"""
reverted_input_mappings = {v: k for k, v in self.input_mappings.items()}
- return [
- [
- {reverted_input_mappings.get(k, k): v for k, v in row.items()}
- for row in row_inputs
- ]
- for row_inputs in inputs
- ]
+ renamed_inputs = []
+ overriden_inputs = []
+ for i, row_inputs in enumerate(inputs):
+ renamed_row_inputs = []
+ for row in row_inputs:
+ overriden_keys = {}
+ renamed_row = {}
+ for k, v in row.items():
+ renamed_key = reverted_input_mappings.get(k, k)
+
+ if renamed_key not in renamed_row or k != renamed_key:
+ renamed_row[renamed_key] = v
+
+ if k != renamed_key and renamed_key in row and len(inputs) == 1:
+ overriden_keys[renamed_key] = row[renamed_key]
+
+ if i == 0:
+ overriden_inputs.append(overriden_keys)
+ renamed_row_inputs.append(renamed_row)
+ renamed_inputs.append(renamed_row_inputs)
+ return tuple(renamed_inputs), overriden_inputs
+
+ def _apply_mappings_and_restore_overriden(
+ self, row: Dict[str, Any], overriden: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """Reverts the `input_mappings` applied to the input rows and applies the `output_mappings`
+ to the output rows. In addition, it restores the overriden values that were replaced
+ by the `input_mappings`.
+
+ Args:
+ row: The output row.
+ overriden: The overriden values that were replaced by the `input_mappings`.
+
+ Returns:
+ The output row with the `output_mappings` applied and the overriden values
+ restored.
+ """
+ result = {}
+ for k, v in row.items():
+ mapped_key = (
+ self.output_mappings.get(k, None)
+ or self.input_mappings.get(k, None)
+ or k
+ )
+ result[mapped_key] = v
+
+ # Restore overriden values
+ for k, v in overriden.items():
+ if k not in result:
+ result[k] = v
+
+ return result
class GeneratorStep(_Step, ABC):
@@ -658,9 +820,9 @@ class GlobalStep(Step, ABC):
"""
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
return []
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
return []
diff --git a/src/distilabel/steps/constants.py b/src/distilabel/steps/clustering/__init__.py
similarity index 87%
rename from src/distilabel/steps/constants.py
rename to src/distilabel/steps/clustering/__init__.py
index 259780e9fd..20ce00bda7 100644
--- a/src/distilabel/steps/constants.py
+++ b/src/distilabel/steps/clustering/__init__.py
@@ -12,6 +12,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Final
-
-DISTILABEL_METADATA_KEY: Final[str] = "distilabel_metadata"
diff --git a/src/distilabel/steps/clustering/dbscan.py b/src/distilabel/steps/clustering/dbscan.py
new file mode 100644
index 0000000000..03ac5dcb3e
--- /dev/null
+++ b/src/distilabel/steps/clustering/dbscan.py
@@ -0,0 +1,177 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.util
+from typing import TYPE_CHECKING, Any, List, Optional
+
+import numpy as np
+from pydantic import Field, PrivateAttr
+
+from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps import (
+ GlobalStep,
+ StepInput,
+)
+
+if TYPE_CHECKING:
+ from sklearn.cluster import DBSCAN as _DBSCAN
+
+ from distilabel.steps.typing import StepOutput
+
+
+class DBSCAN(GlobalStep):
+ r"""DBSCAN (Density-Based Spatial Clustering of Applications with Noise) finds core
+ samples in regions of high density and expands clusters from them. This algorithm
+ is good for data which contains clusters of similar density.
+
+ This is a `GlobalStep` that clusters the embeddings using the DBSCAN algorithm
+ from `sklearn`. Visit `TextClustering` step for an example of use.
+ The trained model is saved as an artifact when creating a distiset
+ and pushing it to the Hugging Face Hub.
+
+ Input columns:
+ - projection (`List[float]`): Vector representation of the text to cluster,
+ normally the output from the `UMAP` step.
+
+ Output columns:
+ - cluster_label (`int`): Integer representing the label of a given cluster. -1
+ means it wasn't clustered.
+
+ Categories:
+ - clustering
+ - text-classification
+
+ References:
+ - [`DBSCAN demo of sklearn`](https://scikit-learn.org/stable/auto_examples/cluster/plot_dbscan.html#demo-of-dbscan-clustering-algorithm)
+ - [`sklearn dbscan`](https://scikit-learn.org/stable/modules/clustering.html#dbscan)
+
+ Attributes:
+ - eps: The maximum distance between two samples for one to be considered as in the
+ neighborhood of the other. This is not a maximum bound on the distances of
+ points within a cluster. This is the most important DBSCAN parameter to
+ choose appropriately for your data set and distance function.
+ - min_samples: The number of samples (or total weight) in a neighborhood for a point
+ to be considered as a core point. This includes the point itself. If `min_samples`
+ is set to a higher value, DBSCAN will find denser clusters, whereas if it is set
+ to a lower value, the found clusters will be more sparse.
+ - metric: The metric to use when calculating distance between instances in a feature
+ array. If metric is a string or callable, it must be one of the options allowed
+ by `sklearn.metrics.pairwise_distances` for its metric parameter.
+ - n_jobs: The number of parallel jobs to run.
+
+ Runtime parameters:
+ - `eps`: The maximum distance between two samples for one to be considered as in the
+ neighborhood of the other. This is not a maximum bound on the distances of
+ points within a cluster. This is the most important DBSCAN parameter to
+ choose appropriately for your data set and distance function.
+ - `min_samples`: The number of samples (or total weight) in a neighborhood for a point
+ to be considered as a core point. This includes the point itself. If `min_samples`
+ is set to a higher value, DBSCAN will find denser clusters, whereas if it is set
+ to a lower value, the found clusters will be more sparse.
+ - `metric`: The metric to use when calculating distance between instances in a feature
+ array. If metric is a string or callable, it must be one of the options allowed
+ by `sklearn.metrics.pairwise_distances` for its metric parameter.
+ - `n_jobs`: The number of parallel jobs to run.
+ """
+
+ eps: Optional[RuntimeParameter[float]] = Field(
+ default=0.3,
+ description=(
+ "The maximum distance between two samples for one to be considered "
+ "as in the neighborhood of the other. This is not a maximum bound "
+ "on the distances of points within a cluster. This is the most "
+ "important DBSCAN parameter to choose appropriately for your data set "
+ "and distance function."
+ ),
+ )
+ min_samples: Optional[RuntimeParameter[int]] = Field(
+ default=30,
+ description=(
+ "The number of samples (or total weight) in a neighborhood for a point to "
+ "be considered as a core point. This includes the point itself. If "
+ "`min_samples` is set to a higher value, DBSCAN will find denser clusters, "
+ "whereas if it is set to a lower value, the found clusters will be more "
+ "sparse."
+ ),
+ )
+ metric: Optional[RuntimeParameter[str]] = Field(
+ default="euclidean",
+ description=(
+ "The metric to use when calculating distance between instances in a "
+ "feature array. If metric is a string or callable, it must be one of "
+ "the options allowed by `sklearn.metrics.pairwise_distances` for "
+ "its metric parameter."
+ ),
+ )
+ n_jobs: Optional[RuntimeParameter[int]] = Field(
+ default=8, description="The number of parallel jobs to run."
+ )
+
+ _clusterer: Optional["_DBSCAN"] = PrivateAttr(None)
+
+ def load(self) -> None:
+ super().load()
+ if importlib.util.find_spec("sklearn") is None:
+ raise ImportError(
+ "`sklearn` package is not installed. Please install it using `pip install scikit-learn`."
+ )
+ from sklearn.cluster import DBSCAN as _DBSCAN
+
+ self._clusterer = _DBSCAN(
+ eps=self.eps,
+ min_samples=self.min_samples,
+ metric=self.metric,
+ n_jobs=self.n_jobs,
+ )
+
+ def unload(self) -> None:
+ self._clusterer = None
+
+ @property
+ def inputs(self) -> List[str]:
+ return ["projection"]
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["cluster_label"]
+
+ def _save_model(self, model: Any) -> None:
+ import joblib
+
+ def save_model(path):
+ with open(str(path / "DBSCAN.joblib"), "wb") as f:
+ joblib.dump(model, f)
+
+ self.save_artifact(
+ name="DBSCAN_model",
+ write_function=lambda path: save_model(path),
+ metadata={
+ "eps": self.eps,
+ "min_samples": self.min_samples,
+ "metric": self.metric,
+ },
+ )
+
+ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
+ projections = np.array([input["projection"] for input in inputs])
+
+ self._logger.info("🏋️♀️ Start training DBSCAN...")
+ fitted_clusterer = self._clusterer.fit(projections)
+ cluster_labels = fitted_clusterer.labels_
+ # Sets the cluster labels for each input, -1 means it wasn't clustered
+ for input, cluster_label in zip(inputs, cluster_labels):
+ input["cluster_label"] = cluster_label
+ self._logger.info(f"DBSCAN labels assigned: {len(set(cluster_labels))}")
+ self._save_model(fitted_clusterer)
+ yield inputs
diff --git a/src/distilabel/steps/clustering/text_clustering.py b/src/distilabel/steps/clustering/text_clustering.py
new file mode 100644
index 0000000000..7e640bf5c1
--- /dev/null
+++ b/src/distilabel/steps/clustering/text_clustering.py
@@ -0,0 +1,326 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.util
+import json
+from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+from pydantic import Field
+
+from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps import StepInput
+from distilabel.steps.tasks import TextClassification
+from distilabel.steps.tasks.base import GlobalTask
+from distilabel.utils.itertools import batched
+
+if TYPE_CHECKING:
+ from distilabel.steps.typing import StepOutput
+
+
+class TextClustering(TextClassification, GlobalTask):
+ """Task that clusters a set of texts and generates summary labels for each cluster.
+
+ This is a `GlobalTask` that inherits from `TextClassification`, this means that all
+ the attributes from that class are available here. Also, in this case we deal
+ with all the inputs at once, instead of using batches. The `input_batch_size` is
+ used here to send the examples to the LLM in batches (a subtle difference with the
+ more common `Task` definitions).
+ The task looks in each cluster for a given number of representative examples (the number
+ is set by the `samples_per_cluster` attribute), and sends them to the LLM to get a label/s
+ that represent the cluster. The labels are then assigned to each text in the cluster.
+ The clusters and projections used in the step, are assumed to be obtained from the `UMAP`
+ + `DBSCAN` steps, but could be generated for similar steps, as long as they represent the
+ same concepts.
+ This step runs a pipeline like the one in this repository:
+ https://github.com/huggingface/text-clustering
+
+ Input columns:
+ - text (`str`): The reference text we want to obtain labels for.
+ - projection (`List[float]`): Vector representation of the text to cluster,
+ normally the output from the `UMAP` step.
+ - cluster_label (`int`): Integer representing the label of a given cluster. -1
+ means it wasn't clustered.
+
+ Output columns:
+ - summary_label (`str`): The label or list of labels for the text.
+ - model_name (`str`): The name of the model used to generate the label/s.
+
+ Categories:
+ - clustering
+ - text-classification
+
+ References:
+ - [`text-clustering repository`](https://github.com/huggingface/text-clustering)
+
+ Attributes:
+ - savefig: Whether to generate and save a figure with the clustering of the texts.
+ - samples_per_cluster: The number of examples to use in the LLM as a sample of the cluster.
+
+ Examples:
+ Generate labels for a set of texts using clustering:
+
+ ```python
+ from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.steps import UMAP, DBSCAN, TextClustering
+ from distilabel.pipeline import Pipeline
+
+ ds_name = "argilla-warehouse/personahub-fineweb-edu-4-clustering-100k"
+
+ with Pipeline(name="Text clustering dataset") as pipeline:
+ batch_size = 500
+
+ ds = load_dataset(ds_name, split="train").select(range(10000))
+ loader = make_generator_step(ds, batch_size=batch_size, repo_id=ds_name)
+
+ umap = UMAP(n_components=2, metric="cosine")
+ dbscan = DBSCAN(eps=0.3, min_samples=30)
+
+ text_clustering = TextClustering(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+ n=3, # 3 labels per example
+ query_title="Examples of Personas",
+ samples_per_cluster=10,
+ context=(
+ "Describe the main themes, topics, or categories that could describe the "
+ "following types of personas. All the examples of personas must share "
+ "the same set of labels."
+ ),
+ default_label="None",
+ savefig=True,
+ input_batch_size=8,
+ input_mappings={"text": "persona"},
+ use_default_structured_output=True,
+ )
+
+ loader >> umap >> dbscan >> text_clustering
+ ```
+ """
+
+ savefig: Optional[RuntimeParameter[bool]] = Field(
+ default=True,
+ description="Whether to generate and save a figure with the clustering of the texts.",
+ )
+ samples_per_cluster: int = Field(
+ default=10,
+ description="The number of examples to use in the LLM as a sample of the cluster.",
+ )
+
+ @property
+ def inputs(self) -> List[str]:
+ """The input for the task are the same as those for `TextClassification` plus
+ the `projection` and `cluster_label` columns (which can be obtained from
+ UMAP + DBSCAN steps).
+ """
+ return super().inputs + ["projection", "cluster_label"]
+
+ @property
+ def outputs(self) -> List[str]:
+ """The output for the task is the `summary_label` and the `model_name`."""
+ return ["summary_label", "model_name"]
+
+ def load(self) -> None:
+ super().load()
+ if self.savefig and (importlib.util.find_spec("matplotlib") is None):
+ raise ImportError(
+ "`matplotlib` package is not installed. Please install it using `pip install matplotlib`."
+ )
+
+ def _save_figure(
+ self,
+ data: pd.DataFrame,
+ cluster_centers: Dict[str, Tuple[float, float]],
+ cluster_summaries: Dict[int, str],
+ ) -> None:
+ """Saves the figure starting from the dataframe, using matplotlib.
+
+ Args:
+ data: pd.DataFrame with the columns 'X', 'Y' and 'labels' representing
+ the projections and the label of each text respectively.
+ cluster_centers: Dictionary mapping from each label the center of a cluster,
+ to help with the placement of the annotations.
+ cluster_summaries: The summaries of the clusters, obtained from the LLM.
+ """
+ import matplotlib.pyplot as plt
+
+ fig, ax = plt.subplots(figsize=(12, 8), dpi=300)
+ unique_labels = data["labels"].unique()
+ # Map of colors for each label (-1 is black)
+ colormap = dict(
+ zip(unique_labels, plt.cm.Spectral(np.linspace(0, 1, len(unique_labels))))
+ )
+ colormap[-1] = np.array([0, 0, 0, 0])
+ data["color"] = data["labels"].map(colormap)
+
+ data.plot(
+ kind="scatter",
+ x="X",
+ y="Y",
+ c="color",
+ s=0.75,
+ alpha=0.8,
+ linewidth=0.4,
+ ax=ax,
+ colorbar=False,
+ )
+
+ for label in cluster_summaries.keys():
+ if label == -1:
+ continue
+ summary = str(cluster_summaries[label]) # These are obtained from the LLM
+ position = cluster_centers[label]
+ t = ax.text(
+ position[0],
+ position[1],
+ summary,
+ horizontalalignment="center",
+ verticalalignment="center",
+ fontsize=4,
+ )
+ t.set_bbox(
+ {
+ "facecolor": "white",
+ "alpha": 0.9,
+ "linewidth": 0,
+ "boxstyle": "square,pad=0.1",
+ }
+ )
+
+ ax.set_axis_off()
+ # Save the plot as an artifact of the step
+ self.save_artifact(
+ name="Text clusters",
+ write_function=lambda path: fig.savefig(path / "figure_clustering.png"),
+ metadata={"type": "image", "library": "matplotlib"},
+ )
+ plt.close()
+
+ def _create_figure(
+ self,
+ inputs: StepInput,
+ label2docs: Dict[int, List[str]],
+ cluster_summaries: Dict[int, str],
+ ) -> None:
+ """Creates a figure of the clustered texts and save it as an artifact.
+
+ Args:
+ inputs: The inputs of the step, as we will extract information from them again.
+ label2docs: Map from each label to the list of documents (texts) that belong to that cluster.
+ cluster_summaries: The summaries of the clusters, obtained from the LLM.
+ """
+ self._logger.info("🖼️ Creating figure for the clusters...")
+
+ labels = []
+ projections = []
+ id2cluster = {}
+ for i, input in enumerate(inputs):
+ label = input["cluster_label"]
+ id2cluster[i] = label
+ labels.append(label)
+ projections.append(input["projection"])
+
+ projections = np.array(projections)
+
+ # Contains the placement of the cluster centers in the figure
+ cluster_centers: Dict[str, Tuple[float, float]] = {}
+ for label in label2docs.keys():
+ x = np.mean([projections[doc, 0] for doc in label2docs[label]])
+ y = np.mean([projections[doc, 1] for doc in label2docs[label]])
+ cluster_centers[label] = (x, y)
+
+ df = pd.DataFrame(
+ data={
+ "X": projections[:, 0],
+ "Y": projections[:, 1],
+ "labels": labels,
+ }
+ )
+
+ self._save_figure(
+ df, cluster_centers=cluster_centers, cluster_summaries=cluster_summaries
+ )
+
+ def _prepare_input_texts(
+ self,
+ inputs: StepInput,
+ label2docs: Dict[int, List[int]],
+ unique_labels: List[int],
+ ) -> List[Dict[str, Union[str, int]]]:
+ """Prepares a batch of inputs to send to the LLM, with the examples of each cluster.
+
+ Args:
+ inputs: Inputs from the step.
+ label2docs: Map from each label to the list of documents (texts) that
+ belong to that cluster.
+ unique_labels: The unique labels of the clusters.
+
+ Returns:
+ The input texts to send to the LLM, with the examples of each cluster
+ prepared to be used in the prompt, and an additional key to store the
+ labels (that will be needed to find the data after the batches are
+ returned from the LLM).
+ """
+ input_texts = []
+ for label in range(unique_labels): # The label -1 is implicitly excluded
+ # Get the ids but remove possible duplicates, which could happen with bigger probability
+ # the bigger the number of examples requested, and the smaller the subset of examples
+ ids = set(
+ np.random.choice(label2docs[label], size=self.samples_per_cluster)
+ ) # Grab the number of examples
+ examples = [inputs[i]["text"] for i in ids]
+ input_text = {
+ "text": "\n\n".join(
+ [f"Example {i}:\n{t}" for i, t in enumerate(examples, start=1)]
+ ),
+ "__LABEL": label,
+ }
+ input_texts.append(input_text)
+ return input_texts
+
+ def process(self, inputs: StepInput) -> "StepOutput":
+ labels = [input["cluster_label"] for input in inputs]
+ # -1 because -1 is the label for the unclassified
+ unique_labels = len(set(labels)) - 1
+ # This will be the output of the LLM, the set of labels for each cluster
+ cluster_summaries: Dict[int, str] = {-1: self.default_label}
+
+ # Map from label to list of documents, will use them to select examples from each cluster
+ label2docs = defaultdict(list)
+ for i, label in enumerate(labels):
+ label2docs[label].append(i)
+
+ input_texts = self._prepare_input_texts(inputs, label2docs, unique_labels)
+
+ # Send the texts in batches to the LLM, and get the labels for each cluster
+ for i, batched_inputs in enumerate(batched(input_texts, self.input_batch_size)):
+ self._logger.info(f"📦 Processing internal batch of inputs {i}...")
+ results = super().process(batched_inputs)
+ for result in next(results): # Extract the elements from the generator
+ cluster_summaries[result["__LABEL"]] = result["labels"]
+
+ # Assign the labels to each text
+ for input in inputs:
+ input["summary_label"] = json.dumps(
+ cluster_summaries[input["cluster_label"]]
+ )
+
+ if self.savefig:
+ self._create_figure(inputs, label2docs, cluster_summaries)
+
+ yield inputs
diff --git a/src/distilabel/steps/clustering/umap.py b/src/distilabel/steps/clustering/umap.py
new file mode 100644
index 0000000000..daeb37486d
--- /dev/null
+++ b/src/distilabel/steps/clustering/umap.py
@@ -0,0 +1,164 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.util
+from typing import TYPE_CHECKING, Any, List, Optional
+
+import numpy as np
+from pydantic import Field, PrivateAttr
+
+from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps import (
+ GlobalStep,
+ StepInput,
+)
+
+if TYPE_CHECKING:
+ from umap import UMAP as _UMAP
+
+ from distilabel.steps.typing import StepOutput
+
+
+class UMAP(GlobalStep):
+ r"""UMAP is a general purpose manifold learning and dimension reduction algorithm.
+
+ This is a `GlobalStep` that reduces the dimensionality of the embeddings using. Visit
+ the `TextClustering` step for an example of use. The trained model is saved as an artifact
+ when creating a distiset and pushing it to the Hugging Face Hub.
+
+ Input columns:
+ - embedding (`List[float]`): The original embeddings we want to reduce the dimension.
+
+ Output columns:
+ - projection (`List[float]`): Embedding reduced to the number of components specified,
+ the size of the new embeddings will be determined by the `n_components`.
+
+ Categories:
+ - clustering
+ - text-classification
+
+ References:
+ - [`UMAP repository`](https://github.com/lmcinnes/umap/tree/master)
+ - [`UMAP documentation`](https://umap-learn.readthedocs.io/en/latest/)
+
+ Attributes:
+ - n_components: The dimension of the space to embed into. This defaults to 2 to
+ provide easy visualization (that's probably what you want), but can
+ reasonably be set to any integer value in the range 2 to 100.
+ - metric: The metric to use to compute distances in high dimensional space.
+ Visit UMAP's documentation for more information. Defaults to `euclidean`.
+ - n_jobs: The number of parallel jobs to run. Defaults to `8`.
+ - random_state: The random state to use for the UMAP algorithm.
+
+ Runtime parameters:
+ - `n_components`: The dimension of the space to embed into. This defaults to 2 to
+ provide easy visualization (that's probably what you want), but can
+ reasonably be set to any integer value in the range 2 to 100.
+ - `metric`: The metric to use to compute distances in high dimensional space.
+ Visit UMAP's documentation for more information. Defaults to `euclidean`.
+ - `n_jobs`: The number of parallel jobs to run. Defaults to `8`.
+ - `random_state`: The random state to use for the UMAP algorithm.
+
+ Citations:
+ ```
+ @misc{mcinnes2020umapuniformmanifoldapproximation,
+ title={UMAP: Uniform Manifold Approximation and Projection for Dimension Reduction},
+ author={Leland McInnes and John Healy and James Melville},
+ year={2020},
+ eprint={1802.03426},
+ archivePrefix={arXiv},
+ primaryClass={stat.ML},
+ url={https://arxiv.org/abs/1802.03426},
+ }
+ ```
+ """
+
+ n_components: Optional[RuntimeParameter[int]] = Field(
+ default=2,
+ description=(
+ "The dimension of the space to embed into. This defaults to 2 to "
+ "provide easy visualization, but can reasonably be set to any "
+ "integer value in the range 2 to 100."
+ ),
+ )
+ metric: Optional[RuntimeParameter[str]] = Field(
+ default="euclidean",
+ description=(
+ "The metric to use to compute distances in high dimensional space. "
+ "Visit UMAP's documentation for more information."
+ ),
+ )
+ n_jobs: Optional[RuntimeParameter[int]] = Field(
+ default=8, description="The number of parallel jobs to run."
+ )
+ random_state: Optional[RuntimeParameter[int]] = Field(
+ default=None, description="The random state to use for the UMAP algorithm."
+ )
+
+ _umap: Optional["_UMAP"] = PrivateAttr(None)
+
+ def load(self) -> None:
+ super().load()
+ if importlib.util.find_spec("umap") is None:
+ raise ImportError(
+ "`umap` package is not installed. Please install it using `pip install umap-learn`."
+ )
+ from umap import UMAP as _UMAP
+
+ self._umap = _UMAP(
+ n_components=self.n_components,
+ metric=self.metric,
+ n_jobs=self.n_jobs,
+ random_state=self.random_state,
+ )
+
+ def unload(self) -> None:
+ self._umap = None
+
+ @property
+ def inputs(self) -> List[str]:
+ return ["embedding"]
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["projection"]
+
+ def _save_model(self, model: Any) -> None:
+ import joblib
+
+ def save_model(path):
+ with open(str(path / "UMAP.joblib"), "wb") as f:
+ joblib.dump(model, f)
+
+ self.save_artifact(
+ name="UMAP_model",
+ write_function=lambda path: save_model(path),
+ metadata={
+ "n_components": self.n_components,
+ "metric": self.metric,
+ },
+ )
+
+ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
+ # Shape of the embeddings is (n_samples, n_features)
+ embeddings = np.array([input["embedding"] for input in inputs])
+
+ self._logger.info("🏋️♀️ Start UMAP training...")
+ mapper = self._umap.fit(embeddings)
+ # Shape of the projection will be (n_samples, n_components)
+ for input, projection in zip(inputs, mapper.embedding_):
+ input["projection"] = projection
+
+ self._save_model(mapper)
+ yield inputs
diff --git a/src/distilabel/steps/columns/combine.py b/src/distilabel/steps/columns/combine.py
new file mode 100644
index 0000000000..784beffe47
--- /dev/null
+++ b/src/distilabel/steps/columns/combine.py
@@ -0,0 +1,99 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from distilabel.constants import DISTILABEL_METADATA_KEY
+from distilabel.steps.base import Step, StepInput
+from distilabel.steps.columns.utils import merge_distilabel_metadata
+
+if TYPE_CHECKING:
+ from distilabel.steps.typing import StepOutput
+
+
+class CombineOutputs(Step):
+ """Combine the outputs of several upstream steps.
+
+ `CombineOutputs` is a `Step` that takes the outputs of several upstream steps and combines
+ them to generate a new dictionary with all keys/columns of the upstream steps outputs.
+
+ Input columns:
+ - dynamic (based on the upstream `Step`s): All the columns of the upstream steps outputs.
+
+ Output columns:
+ - dynamic (based on the upstream `Step`s): All the columns of the upstream steps outputs.
+
+ Categories:
+ - columns
+
+ Examples:
+
+ Combine dictionaries of a dataset:
+
+ ```python
+ from distilabel.steps import CombineOutputs
+
+ combine_outputs = CombineOutputs()
+ combine_outputs.load()
+
+ result = next(
+ combine_outputs.process(
+ [{"a": 1, "b": 2}, {"a": 3, "b": 4}],
+ [{"c": 5, "d": 6}, {"c": 7, "d": 8}],
+ )
+ )
+ # [
+ # {"a": 1, "b": 2, "c": 5, "d": 6},
+ # {"a": 3, "b": 4, "c": 7, "d": 8},
+ # ]
+ ```
+
+ Combine upstream steps outputs in a pipeline:
+
+ ```python
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps import CombineOutputs
+
+ with Pipeline() as pipeline:
+ step_1 = ...
+ step_2 = ...
+ step_3 = ...
+ combine = CombineOutputs()
+
+ [step_1, step_2, step_3] >> combine
+ ```
+ """
+
+ def process(self, *inputs: StepInput) -> "StepOutput":
+ combined_outputs = []
+ for output_dicts in zip(*inputs):
+ combined_dict = {}
+ for output_dict in output_dicts:
+ combined_dict.update(
+ {
+ k: v
+ for k, v in output_dict.items()
+ if k != DISTILABEL_METADATA_KEY
+ }
+ )
+
+ if any(
+ DISTILABEL_METADATA_KEY in output_dict for output_dict in output_dicts
+ ):
+ combined_dict[DISTILABEL_METADATA_KEY] = merge_distilabel_metadata(
+ *output_dicts
+ )
+ combined_outputs.append(combined_dict)
+
+ yield combined_outputs
diff --git a/src/distilabel/steps/columns/expand.py b/src/distilabel/steps/columns/expand.py
index 7312e1a4fd..709ca4bc66 100644
--- a/src/distilabel/steps/columns/expand.py
+++ b/src/distilabel/steps/columns/expand.py
@@ -20,7 +20,7 @@
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
class ExpandColumns(Step):
@@ -42,8 +42,10 @@ class ExpandColumns(Step):
Output columns:
- dynamic (determined by `columns` attribute): The expanded columns.
- Examples:
+ Categories:
+ - columns
+ Examples:
Expand the selected columns into multiple rows:
```python
@@ -87,12 +89,12 @@ def always_dict(cls, value: Union[Dict[str, str], List[str]]) -> Dict[str, str]:
return value
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
"""The columns to be expanded."""
return list(self.columns.keys())
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""The expanded columns."""
return [
new_column if new_column else expand_column
diff --git a/src/distilabel/steps/columns/group.py b/src/distilabel/steps/columns/group.py
index ff761b07a1..876af1f0ad 100644
--- a/src/distilabel/steps/columns/group.py
+++ b/src/distilabel/steps/columns/group.py
@@ -17,11 +17,11 @@
from typing_extensions import override
-from distilabel.pipeline.utils import group_columns
from distilabel.steps.base import Step, StepInput
+from distilabel.steps.columns.utils import group_columns
if TYPE_CHECKING:
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
class GroupColumns(Step):
@@ -43,9 +43,12 @@ class GroupColumns(Step):
- dynamic (determined by `columns` and `output_columns` attributes): The columns
that were grouped.
+ Categories:
+ - columns
+
Examples:
- Combine columns of a dataset:
+ Group columns of a dataset:
```python
from distilabel.steps import GroupColumns
@@ -93,12 +96,12 @@ class GroupColumns(Step):
output_columns: Optional[List[str]] = None
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
"""The inputs for the task are the column names in `columns`."""
return self.columns
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""The outputs for the task are the column names in `output_columns` or
`grouped_{column}` for each column in `columns`."""
return (
@@ -125,6 +128,8 @@ def process(self, *inputs: StepInput) -> "StepOutput":
class CombineColumns(GroupColumns):
+ """`CombineColumns` is deprecated and will be removed in version 1.5.0, use `GroupColumns` instead."""
+
def __init__(self, **data: Any) -> None:
warnings.warn(
"`CombineColumns` is deprecated and will be removed in version 1.5.0, use `GroupColumns` instead.",
diff --git a/src/distilabel/steps/columns/keep.py b/src/distilabel/steps/columns/keep.py
index 58380660fa..c12dfdd61d 100644
--- a/src/distilabel/steps/columns/keep.py
+++ b/src/distilabel/steps/columns/keep.py
@@ -19,7 +19,7 @@
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
class KeepColumns(Step):
@@ -44,8 +44,10 @@ class KeepColumns(Step):
Output columns:
- dynamic (determined by `columns` attribute): The columns that were kept.
- Examples:
+ Categories:
+ - columns
+ Examples:
Select the columns to keep:
```python
@@ -69,12 +71,12 @@ class KeepColumns(Step):
columns: List[str]
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
"""The inputs for the task are the column names in `columns`."""
return self.columns
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""The outputs for the task are the column names in `columns`."""
return self.columns
diff --git a/src/distilabel/steps/columns/merge.py b/src/distilabel/steps/columns/merge.py
index 390de687d0..54ab3e3c75 100644
--- a/src/distilabel/steps/columns/merge.py
+++ b/src/distilabel/steps/columns/merge.py
@@ -16,11 +16,11 @@
from typing_extensions import override
-from distilabel.pipeline.utils import merge_columns
from distilabel.steps.base import Step, StepInput
+from distilabel.steps.columns.utils import merge_columns
if TYPE_CHECKING:
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
class MergeColumns(Step):
@@ -47,8 +47,10 @@ class MergeColumns(Step):
- dynamic (determined by `columns` and `output_column` attributes): The columns
that were merged.
- Examples:
+ Categories:
+ - columns
+ Examples:
Combine columns in rows of a dataset:
```python
@@ -79,11 +81,11 @@ class MergeColumns(Step):
output_column: Optional[str] = None
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
return self.columns
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
return [self.output_column] if self.output_column else ["merged_column"]
@override
diff --git a/src/distilabel/pipeline/utils.py b/src/distilabel/steps/columns/utils.py
similarity index 64%
rename from src/distilabel/pipeline/utils.py
rename to src/distilabel/steps/columns/utils.py
index 5758053a51..7b3efe2262 100644
--- a/src/distilabel/pipeline/utils.py
+++ b/src/distilabel/steps/columns/utils.py
@@ -12,16 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
-from distilabel.steps.base import StepInput
+from distilabel.constants import DISTILABEL_METADATA_KEY
+
+if TYPE_CHECKING:
+ from distilabel.steps.base import StepInput
+
+
+def merge_distilabel_metadata(*output_dicts: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Merge the `DISTILABEL_METADATA_KEY` from multiple output dictionaries.
+
+ Args:
+ *output_dicts: Variable number of dictionaries containing distilabel metadata.
+
+ Returns:
+ A merged dictionary containing all the distilabel metadata from the input dictionaries.
+ """
+ merged_metadata = defaultdict(list)
+
+ for output_dict in output_dicts:
+ metadata = output_dict.get(DISTILABEL_METADATA_KEY, {})
+ for key, value in metadata.items():
+ merged_metadata[key].append(value)
+
+ final_metadata = {}
+ for key, value_list in merged_metadata.items():
+ if len(value_list) == 1:
+ final_metadata[key] = value_list[0]
+ else:
+ final_metadata[key] = value_list
+
+ return final_metadata
def group_columns(
- *inputs: StepInput,
+ *inputs: "StepInput",
group_columns: List[str],
output_group_columns: Optional[List[str]] = None,
-) -> StepInput:
+) -> "StepInput":
"""Groups multiple list of dictionaries into a single list of dictionaries on the
specified `group_columns`. If `group_columns` are provided, then it will also rename
`group_columns`.
@@ -39,7 +70,7 @@ def group_columns(
group_columns
):
raise ValueError(
- "The length of output_group_columns must be the same as the length of group_columns"
+ "The length of `output_group_columns` must be the same as the length of `group_columns`."
)
if output_group_columns is None:
output_group_columns = [f"grouped_{key}" for key in group_columns]
@@ -49,16 +80,30 @@ def group_columns(
# Use zip to iterate over lists based on their index
for dicts_at_index in zip(*inputs):
combined_dict = {}
+ metadata_dicts = []
# Iterate over dicts at the same index
for d in dicts_at_index:
+ # Extract metadata for merging
+ if DISTILABEL_METADATA_KEY in d:
+ metadata_dicts.append(
+ {DISTILABEL_METADATA_KEY: d[DISTILABEL_METADATA_KEY]}
+ )
# Iterate over key-value pairs in each dict
for key, value in d.items():
+ if key == DISTILABEL_METADATA_KEY:
+ continue
# If the key is in the merge_keys, append the value to the existing list
if key in group_columns_dict.keys():
combined_dict.setdefault(group_columns_dict[key], []).append(value)
# If the key is not in the merge_keys, create a new key-value pair
else:
combined_dict[key] = value
+
+ if metadata_dicts:
+ combined_dict[DISTILABEL_METADATA_KEY] = merge_distilabel_metadata(
+ *metadata_dicts
+ )
+
result.append(combined_dict)
return result
diff --git a/src/distilabel/steps/decorator.py b/src/distilabel/steps/decorator.py
index da2cbb8dcc..0816ca13eb 100644
--- a/src/distilabel/steps/decorator.py
+++ b/src/distilabel/steps/decorator.py
@@ -37,7 +37,7 @@
if TYPE_CHECKING:
from distilabel.steps.base import _Step
- from distilabel.steps.typing import GeneratorStepOutput, StepOutput
+ from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput
_STEP_MAPPING = {
"normal": Step,
@@ -50,16 +50,16 @@
@overload
def step(
- inputs: Union[List[str], None] = None,
- outputs: Union[List[str], None] = None,
+ inputs: Union["StepColumns", None] = None,
+ outputs: Union["StepColumns", None] = None,
step_type: Literal["normal"] = "normal",
) -> Callable[..., Type["Step"]]: ...
@overload
def step(
- inputs: Union[List[str], None] = None,
- outputs: Union[List[str], None] = None,
+ inputs: Union["StepColumns", None] = None,
+ outputs: Union["StepColumns", None] = None,
step_type: Literal["global"] = "global",
) -> Callable[..., Type["GlobalStep"]]: ...
@@ -67,26 +67,29 @@ def step(
@overload
def step(
inputs: None = None,
- outputs: Union[List[str], None] = None,
+ outputs: Union["StepColumns", None] = None,
step_type: Literal["generator"] = "generator",
) -> Callable[..., Type["GeneratorStep"]]: ...
def step(
- inputs: Union[List[str], None] = None,
- outputs: Union[List[str], None] = None,
+ inputs: Union["StepColumns", None] = None,
+ outputs: Union["StepColumns", None] = None,
step_type: Literal["normal", "global", "generator"] = "normal",
) -> Callable[..., Type["_Step"]]:
"""Creates an `Step` from a processing function.
Args:
- inputs: a list containing the name of the inputs columns/keys expected by this step.
- If not provided the default will be an empty list `[]` and it will be assumed
- that the step doesn't need any specific columns. Defaults to `None`.
- outputs: a list containing the name of the outputs columns/keys that the step
- will generate. If not provided the default will be an empty list `[]` and it
- will be assumed that the step doesn't need any specific columns. Defaults to
- `None`.
+ inputs: a list containing the name of the inputs columns/keys or a dictionary
+ where the keys are the columns and the values are booleans indicating whether
+ the column is required or not, that are required by the step. If not provided
+ the default will be an empty list `[]` and it will be assumed that the step
+ doesn't need any specific columns. Defaults to `None`.
+ outputs: a list containing the name of the outputs columns/keys or a dictionary
+ where the keys are the columns and the values are booleans indicating whether
+ the column will be generated or not. If not provided the default will be an
+ empty list `[]` and it will be assumed that the step doesn't need any specific
+ columns. Defaults to `None`.
step_type: the kind of step to create. Valid choices are: "normal" (`Step`),
"global" (`GlobalStep`) or "generator" (`GeneratorStep`). Defaults to
`"normal"`.
diff --git a/src/distilabel/steps/deita.py b/src/distilabel/steps/deita.py
index 5d98355c53..f817a4c263 100644
--- a/src/distilabel/steps/deita.py
+++ b/src/distilabel/steps/deita.py
@@ -66,7 +66,6 @@ class DeitaFiltering(GlobalStep):
- [`What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685)
Examples:
-
Filter the dataset based on the DEITA score and the cosine distance between the embeddings:
```python
@@ -102,7 +101,6 @@ class DeitaFiltering(GlobalStep):
```
Citations:
-
```
@misc{liu2024makesgooddataalignment,
title={What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning},
diff --git a/src/distilabel/steps/embeddings/embedding_generation.py b/src/distilabel/steps/embeddings/embedding_generation.py
index 55e8838274..8db3bee2ee 100644
--- a/src/distilabel/steps/embeddings/embedding_generation.py
+++ b/src/distilabel/steps/embeddings/embedding_generation.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING
from distilabel.embeddings.base import Embeddings
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
class EmbeddingGeneration(Step):
@@ -36,8 +36,10 @@ class EmbeddingGeneration(Step):
Output columns:
- embedding (`List[Union[float, int]]`): the generated sentence embedding.
- Examples:
+ Categories:
+ - embedding
+ Examples:
Generate sentence embeddings with Sentence Transformers:
```python
@@ -61,11 +63,11 @@ class EmbeddingGeneration(Step):
embeddings: Embeddings
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
return ["text"]
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
return ["embedding", "model_name"]
def load(self) -> None:
diff --git a/src/distilabel/steps/embeddings/nearest_neighbour.py b/src/distilabel/steps/embeddings/nearest_neighbour.py
index cba8e293da..98b646d9ee 100644
--- a/src/distilabel/steps/embeddings/nearest_neighbour.py
+++ b/src/distilabel/steps/embeddings/nearest_neighbour.py
@@ -46,6 +46,8 @@ class FaissNearestNeighbour(GlobalStep):
search_batch_size: the number of rows to include in a search batch. The value can
be adjusted to maximize the resources usage or to avoid OOM issues. Defaults
to `50`.
+ train_size: If the index needs a training step, specifies how many vectors will be
+ used to train the index.
Runtime parameters:
- `device`: the CUDA device ID or a list of IDs to be used. If negative integer,
@@ -60,6 +62,8 @@ class FaissNearestNeighbour(GlobalStep):
- `search_batch_size`: the number of rows to include in a search batch. The value
can be adjusted to maximize the resources usage or to avoid OOM issues. Defaults
to `50`.
+ - `train_size`: If the index needs a training step, specifies how many vectors will
+ be used to train the index.
Input columns:
- embedding (`List[Union[float, int]]`): a sentence embedding.
@@ -77,7 +81,6 @@ class FaissNearestNeighbour(GlobalStep):
- [`The Faiss library`](https://arxiv.org/abs/2401.08281)
Examples:
-
Generating embeddings and getting the nearest neighbours:
```python
@@ -111,7 +114,6 @@ class FaissNearestNeighbour(GlobalStep):
```
Citations:
-
```
@misc{douze2024faisslibrary,
title={The Faiss library},
@@ -150,6 +152,10 @@ class FaissNearestNeighbour(GlobalStep):
description="The number of rows to include in a search batch. The value can be adjusted"
" to maximize the resources usage or to avoid OOM issues.",
)
+ train_size: Optional[RuntimeParameter[int]] = Field(
+ default=None,
+ description="If the index needs a training step, specifies how many vectors will be used to train the index.",
+ )
def load(self) -> None:
super().load()
@@ -178,14 +184,34 @@ def _build_index(self, inputs: List[Dict[str, Any]]) -> Dataset:
The build `datasets.Dataset` with its `faiss` index.
"""
dataset = Dataset.from_list(inputs)
+ if self.train_size is not None and self.string_factory:
+ self._logger.info("🏋️♀️ Starting Faiss index training...")
dataset.add_faiss_index(
column="embedding",
device=self.device, # type: ignore
string_factory=self.string_factory,
metric_type=self.metric_type,
+ train_size=self.train_size,
)
return dataset
+ def _save_index(self, dataset: Dataset) -> None:
+ """Save the generated Faiss index as an artifact of the step.
+
+ Args:
+ dataset: the dataset with the `faiss` index built.
+ """
+ self.save_artifact(
+ name="faiss_index",
+ write_function=lambda path: dataset.save_faiss_index(
+ index_name="embedding", file=path / "index.faiss"
+ ),
+ metadata={
+ "num_rows": len(dataset),
+ "embedding_dim": len(dataset[0]["embedding"]),
+ },
+ )
+
def _search(self, dataset: Dataset) -> Dataset:
"""Search the top `k` nearest neighbours for each row in the dataset.
@@ -214,5 +240,6 @@ def add_search_results(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
dataset = self._build_index(inputs)
- dataset = self._search(dataset)
- yield dataset.to_list()
+ dataset_with_search_results = self._search(dataset)
+ self._save_index(dataset)
+ yield dataset_with_search_results.to_list()
diff --git a/src/distilabel/pipeline/constants.py b/src/distilabel/steps/filtering/__init__.py
similarity index 61%
rename from src/distilabel/pipeline/constants.py
rename to src/distilabel/steps/filtering/__init__.py
index 3d400e4a1b..20ce00bda7 100644
--- a/src/distilabel/pipeline/constants.py
+++ b/src/distilabel/steps/filtering/__init__.py
@@ -11,12 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-#
-from typing import Final
-STEP_ATTR_NAME: Final[str] = "step"
-INPUT_QUEUE_ATTR_NAME: Final[str] = "input_queue"
-RECEIVES_ROUTED_BATCHES_ATTR_NAME: Final[str] = "receives_routed_batches"
-ROUTING_BATCH_FUNCTION_ATTR_NAME: Final[str] = "routing_batch_function"
-CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step"
-LAST_BATCH_SENT_FLAG: Final[str] = "last_batch_sent"
diff --git a/src/distilabel/steps/filtering/_datasketch.py b/src/distilabel/steps/filtering/_datasketch.py
new file mode 100644
index 0000000000..623cd45f0d
--- /dev/null
+++ b/src/distilabel/steps/filtering/_datasketch.py
@@ -0,0 +1,190 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+`dataskech` (https://github.com/ekzhu/datasketch) doesn't offer a way to store the hash tables in disk. This
+is a custom implementation that uses `shelve` to store the hash tables in disk.
+Note: This implementation is not optimized for performance, but could be worth
+creating a PR to `datasketch`.
+"""
+
+import shutil
+import struct
+from pathlib import Path
+from typing import Callable, Dict, Final, Optional, Tuple
+
+from datasketch import MinHashLSH as _MinHashLSH
+from datasketch.lsh import _optimal_param
+from datasketch.storage import OrderedStorage, UnorderedStorage, _random_name
+from datasketch.storage import ordered_storage as _ordered_storage
+from datasketch.storage import unordered_storage as _unordered_storage
+
+KEY_VALUE_DISK_DIR: Path = Path.home() / ".cache" / "distilabel" / "key_value_store"
+KV_DISK_LIST_NAME: Final[str] = "disckache_list_storage"
+KV_DISK_SET_NAME: Final[str] = "diskcache_set_storage"
+
+
+class DiskCacheListStorage(OrderedStorage):
+ def __init__(self, config, name) -> None:
+ path = config.get("path", self._get_db_name(name))
+ try:
+ from diskcache import Index
+ except ImportError as e:
+ raise ImportError(
+ "`diskcache` is required for disk storage using `MinHashDedup`. "
+ "Please install it using `pip install diskcache`."
+ ) from e
+
+ # Start with a clean file on each pipeline
+ if Path(path).exists():
+ shutil.rmtree(path)
+ self._db = Index(path)
+
+ def _get_db_name(self, name):
+ return str(KEY_VALUE_DISK_DIR / f"{name}_{KV_DISK_LIST_NAME}")
+
+ def keys(self):
+ return self._db.keys()
+
+ def get(self, key):
+ return self._db.get(key, [])
+
+ def remove(self, *keys):
+ self._db.clear()
+
+ def remove_val(self, key, val):
+ self.get(key).remove(val)
+
+ def insert(self, key, *vals, **kwargs):
+ res = self.get(key)
+ res.extend(vals)
+ self._db[key] = res
+
+ def size(self):
+ return len(self._db)
+
+ def itemcounts(self, **kwargs):
+ return {k: len(v) for k, v in self._db.items()}
+
+ def has_key(self, key):
+ return key in self._db
+
+ def close(self):
+ self._db._cache.close()
+
+
+class DiskCacheSetStorage(UnorderedStorage, DiskCacheListStorage):
+ def _get_db_name(self, name):
+ return str(KEY_VALUE_DISK_DIR / f"{name}_{KV_DISK_SET_NAME}")
+
+ def get(self, key):
+ return self._db.get(key, set())
+
+ def insert(self, key, *vals, **kwargs):
+ res = self.get(key)
+ res.update(vals)
+ self._db[key] = res
+
+
+def ordered_storage(config, name=None):
+ """Copy of `datasketch.storage.ordered_storage` with the addition of `ShelveListStorage`."""
+ tp = config["type"]
+ if tp == "disk":
+ return DiskCacheListStorage(config, name=name)
+ return _ordered_storage(config, name=name)
+
+
+def unordered_storage(config, name=None):
+ """Copy of `datasketch.storage.ordered_storage` with the addition of `ShelveSetStorage`."""
+ tp = config["type"]
+ if tp == "disk":
+ return DiskCacheSetStorage(config, name=name)
+ return _unordered_storage(config, name=name)
+
+
+class MinHashLSH(_MinHashLSH):
+ """Custom implementation of `datasketch.MinHashLSH` to allow passing a custom
+ storage configuration to store the hash tables in disk.
+
+ This could be merged in the original repository, the only changes
+ to the __init__ are the additional `close` method, and the use
+ of our custom `ordered_storage` and `unordered_storage` functions.
+ """
+
+ def __init__(
+ self,
+ threshold: float = 0.9,
+ num_perm: int = 128,
+ weights: Tuple[float, float] = (0.5, 0.5),
+ params: Optional[Tuple[int, int]] = None,
+ storage_config: Optional[Dict] = None,
+ prepickle: Optional[bool] = None,
+ hashfunc: Optional[Callable[[bytes], bytes]] = None,
+ ) -> None:
+ storage_config = {"type": "dict"} if not storage_config else storage_config
+ self._buffer_size = 50000
+ if threshold > 1.0 or threshold < 0.0:
+ raise ValueError("threshold must be in [0.0, 1.0]")
+ if num_perm < 2:
+ raise ValueError("Too few permutation functions")
+ if any(w < 0.0 or w > 1.0 for w in weights):
+ raise ValueError("Weight must be in [0.0, 1.0]")
+ if sum(weights) != 1.0:
+ raise ValueError("Weights must sum to 1.0")
+ self.h = num_perm
+ if params is not None:
+ self.b, self.r = params
+ if self.b * self.r > num_perm:
+ raise ValueError(
+ "The product of b and r in params is "
+ "{} * {} = {} -- it must be less than num_perm {}. "
+ "Did you forget to specify num_perm?".format(
+ self.b, self.r, self.b * self.r, num_perm
+ )
+ )
+ else:
+ false_positive_weight, false_negative_weight = weights
+ self.b, self.r = _optimal_param(
+ threshold, num_perm, false_positive_weight, false_negative_weight
+ )
+ if self.b < 2:
+ raise ValueError("The number of bands are too small (b < 2)")
+
+ self.prepickle = (
+ storage_config["type"] == "redis" if not prepickle else prepickle
+ )
+
+ self.hashfunc = hashfunc
+ if hashfunc:
+ self._H = self._hashed_byteswap
+ else:
+ self._H = self._byteswap
+
+ basename = storage_config.get("basename", _random_name(11))
+ self.hashtables = [
+ unordered_storage(
+ storage_config,
+ name=b"".join([basename, b"_bucket_", struct.pack(">H", i)]),
+ )
+ for i in range(self.b)
+ ]
+ self.hashranges = [(i * self.r, (i + 1) * self.r) for i in range(self.b)]
+ self.keys = ordered_storage(storage_config, name=b"".join([basename, b"_keys"]))
+
+ def close(self):
+ """Closes the internal connections."""
+ if isinstance(self.hashtables[0], DiskCacheListStorage):
+ for ht in self.hashtables:
+ ht.close()
+ self.keys.close()
diff --git a/src/distilabel/steps/filtering/embedding.py b/src/distilabel/steps/filtering/embedding.py
new file mode 100644
index 0000000000..cb1e710374
--- /dev/null
+++ b/src/distilabel/steps/filtering/embedding.py
@@ -0,0 +1,192 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List, Optional
+
+import numpy as np
+from pydantic import Field
+from rich.progress import track
+from typing_extensions import override
+
+from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.base import GlobalStep, StepInput
+
+if TYPE_CHECKING:
+ from distilabel.steps.typing import StepOutput
+
+
+class EmbeddingDedup(GlobalStep):
+ """Deduplicates text using embeddings.
+
+ `EmbeddingDedup` is a Step that detects near-duplicates in datasets, using
+ embeddings to compare the similarity between the texts. The typical workflow with this step
+ would include having a dataset with embeddings precomputed, and then (possibly using the
+ `FaissNearestNeighbour`) using the `nn_indices` and `nn_scores`, determine the texts that
+ are duplicate.
+
+ Attributes:
+ threshold: the threshold to consider 2 examples as duplicates.
+ It's dependent on the type of index that was used to generate the embeddings.
+ For example, if the embeddings were generated using cosine similarity, a threshold
+ of `0.9` would make all the texts with a cosine similarity above the value
+ duplicates. Higher values detect less duplicates in such an index, but that should
+ be taken into account when building it. Defaults to `0.9`.
+
+ Runtime Parameters:
+ - `threshold`: the threshold to consider 2 examples as duplicates.
+
+ Input columns:
+ - nn_indices (`List[int]`): a list containing the indices of the `k` nearest neighbours
+ in the inputs for the row.
+ - nn_scores (`List[float]`): a list containing the score or distance to each `k`
+ nearest neighbour in the inputs.
+
+ Output columns:
+ - keep_row_after_embedding_filtering (`bool`): boolean indicating if the piece `text` is
+ not a duplicate i.e. this text should be kept.
+
+ Categories:
+ - filtering
+
+ Examples:
+
+ Deduplicate a list of texts using embedding information:
+
+ ```python
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps import EmbeddingDedup
+ from distilabel.steps import LoadDataFromDicts
+
+ with Pipeline() as pipeline:
+ data = LoadDataFromDicts(
+ data=[
+ {
+ "persona": "A chemistry student or academic researcher interested in inorganic or physical chemistry, likely at an advanced undergraduate or graduate level, studying acid-base interactions and chemical bonding.",
+ "embedding": [
+ 0.018477669046149742,
+ -0.03748236608841726,
+ 0.001919870620352492,
+ 0.024918478063770535,
+ 0.02348063521315178,
+ 0.0038251285566308375,
+ -0.01723884983037716,
+ 0.02881971942372201,
+ ],
+ "nn_indices": [0, 1],
+ "nn_scores": [
+ 0.9164746999740601,
+ 0.782106876373291,
+ ],
+ },
+ {
+ "persona": "A music teacher or instructor focused on theoretical and practical piano lessons.",
+ "embedding": [
+ -0.0023464179614082125,
+ -0.07325472251663565,
+ -0.06058678419516501,
+ -0.02100326928586996,
+ -0.013462744792362657,
+ 0.027368447064244242,
+ -0.003916070100455717,
+ 0.01243614518480423,
+ ],
+ "nn_indices": [0, 2],
+ "nn_scores": [
+ 0.7552462220191956,
+ 0.7261884808540344,
+ ],
+ },
+ {
+ "persona": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.",
+ "embedding": [
+ -0.01630817942328242,
+ -0.023760151552345232,
+ -0.014249650090627883,
+ -0.005713686451446624,
+ -0.016033059279131567,
+ 0.0071440908501058786,
+ -0.05691099643425161,
+ 0.01597412704817784,
+ ],
+ "nn_indices": [1, 2],
+ "nn_scores": [
+ 0.8107735514640808,
+ 0.7172299027442932,
+ ],
+ },
+ ],
+ batch_size=batch_size,
+ )
+ # In general you should do something like this before the deduplication step, to obtain the
+ # `nn_indices` and `nn_scores`. In this case the embeddings are already normalized, so there's
+ # no need for it.
+ # nn = FaissNearestNeighbour(
+ # k=30,
+ # metric_type=faiss.METRIC_INNER_PRODUCT,
+ # search_batch_size=50,
+ # train_size=len(dataset), # The number of embeddings to use for training
+ # string_factory="IVF300_HNSW32,Flat" # To use an index (optional, maybe required for big datasets)
+ # )
+ # Read more about the `string_factory` here:
+ # https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
+
+ embedding_dedup = EmbeddingDedup(
+ threshold=0.8,
+ input_batch_size=batch_size,
+ )
+
+ data >> embedding_dedup
+
+ if __name__ == "__main__":
+ distiset = pipeline.run(use_cache=False)
+ ds = distiset["default"]["train"]
+ # Filter out the duplicates
+ ds_dedup = ds.filter(lambda x: x["keep_row_after_embedding_filtering"])
+ ```
+ """
+
+ threshold: Optional[RuntimeParameter[float]] = Field(
+ default=0.9,
+ description="The threshold to consider 2 examples as duplicates. It's dependent "
+ "on the type of index that was used to generate the embeddings. For example, if "
+ "the embeddings were generated using cosine similarity, a threshold of `0.9` "
+ "would make all the texts with a cosine similarity above the value duplicates. "
+ "Higher values detect less duplicates in such an index, but that should be "
+ "taken into account when building it.",
+ )
+
+ @property
+ def inputs(self) -> List[str]:
+ return ["nn_scores", "nn_indices"]
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["keep_row_after_embedding_filtering"]
+
+ @override
+ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
+ rows_to_remove = set()
+
+ for input in track(inputs, description="Running Embedding deduplication..."):
+ input["keep_row_after_embedding_filtering"] = True
+ indices_scores = np.array(input["nn_scores"]) > self.threshold
+ indices = np.array(input["nn_indices"])[indices_scores]
+ if len(indices) > 0: # If there are any rows found over the threshold
+ rows_to_remove.update(list(indices))
+
+ # Remove duplicates and get the list of rows to remove
+ for idx in rows_to_remove:
+ inputs[idx]["keep_row_after_embedding_filtering"] = False
+
+ yield inputs
diff --git a/src/distilabel/steps/filtering/minhash.py b/src/distilabel/steps/filtering/minhash.py
new file mode 100644
index 0000000000..5b779168a1
--- /dev/null
+++ b/src/distilabel/steps/filtering/minhash.py
@@ -0,0 +1,236 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import uuid
+from functools import partial
+from itertools import tee
+from typing import (
+ TYPE_CHECKING,
+ Callable,
+ Iterable,
+ Iterator,
+ List,
+ Literal,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
+
+from pydantic import PrivateAttr
+
+from distilabel.steps.base import Step, StepInput
+
+if TYPE_CHECKING:
+ from datasketch import MinHash, MinHashLSH
+
+ from distilabel.steps.typing import StepOutput
+
+
+# Copied from: https://github.com/huggingface/datatrove/blob/main/src/datatrove/utils/text.py#L89C1-L95C65
+def ngrams(sequence: Iterable[str], n: int) -> Iterator[Tuple[str, ...]]:
+ iterables = tee(sequence, n)
+
+ for i, sub_iterable in enumerate(iterables): # For each window,
+ for _ in range(i): # iterate through every order of ngrams
+ next(sub_iterable, None) # generate the ngrams within the window.
+ return zip(*iterables) # Unpack and flattens the iterables.
+
+
+def tokenized_on_words(texts: Iterable[str]) -> List[Set[bytes]]:
+ """Tokenizes a list of texts into words, using `nltk.word_tokenize`.
+
+ Args:
+ texts: List of documents to be tokenized.
+
+ Returns:
+ List with the set of tokens for each document.
+ """
+ from nltk.tokenize import word_tokenize
+
+ return [{w.encode("utf-8") for w in word_tokenize(text)} for text in texts]
+
+
+def tokenize_on_ngrams(texts: Iterable[str], n: int = 1) -> List[Set[bytes]]:
+ """Tokenizes a list of texts into ngrams, and returns the set of them as bytes.
+
+ Args:
+ texts: List of documents to be tokenized.
+ n: The size of the ngrams, defaults to 1 (single letters).
+
+ Returns:
+ List with the set of tokens for each document.
+ """
+
+ return [
+ {"".join(ngram).encode("utf-8") for ngram in ngrams(text, n=n)}
+ for text in texts
+ ]
+
+
+class MinHashDedup(Step):
+ """Deduplicates text using `MinHash` and `MinHashLSH`.
+
+ `MinHashDedup` is a Step that detects near-duplicates in datasets. The idea roughly translates
+ to the following steps:
+ 1. Tokenize the text into words or ngrams.
+ 2. Create a `MinHash` for each text.
+ 3. Store the `MinHashes` in a `MinHashLSH`.
+ 4. Check if the `MinHash` is already in the `LSH`, if so, it is a duplicate.
+
+ Attributes:
+ num_perm: the number of permutations to use. Defaults to `128`.
+ seed: the seed to use for the MinHash. This seed must be the same
+ used for `MinHash`, keep in mind when both steps are created. Defaults to `1`.
+ tokenizer: the tokenizer to use. Available ones are `words` or `ngrams`.
+ If `words` is selected, it tokenize the text into words using nltk's
+ word tokenizer. `ngram` estimates the ngrams (together with the size
+ `n`) using. Defaults to `words`.
+ n: the size of the ngrams to use. Only relevant if `tokenizer="ngrams"`. Defaults to `5`.
+ threshold: the threshold to consider two MinHashes as duplicates.
+ Values closer to 0 detect more duplicates. Defaults to `0.9`.
+ storage: the storage to use for the LSH. Can be `dict` to store the index
+ in memory, or `disk`. Keep in mind, `disk` is an experimental feature
+ not defined in `datasketch`, that is based on DiskCache's `Index` class.
+ It should work as a `dict`, but backed by disk, but depending on the system
+ it can be slower. Defaults to `dict`.
+ which uses a custom `shelve` backend. Note the `disk`
+ is an experimetal feature that may cause issues. Defaults to `dict`.
+
+ Input columns:
+ - text (`str`): the texts to be filtered.
+
+ Output columns:
+ - keep_row_after_minhash_filtering (`bool`): boolean indicating if the piece `text` is
+ not a duplicate i.e. this text should be kept.
+
+ Categories:
+ - filtering
+
+ References:
+ - [`datasketch documentation`](https://ekzhu.github.io/datasketch/lsh.html)
+ - [Identifying and Filtering Near-Duplicate Documents](https://cs.brown.edu/courses/cs253/papers/nearduplicate.pdf)
+ - [Diskcache's Index](https://grantjenks.com/docs/diskcache/api.html#diskcache.Index)
+
+ Examples:
+
+ Deduplicate a list of texts using MinHash and MinHashLSH:
+
+ ```python
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps import MinHashDedup
+ from distilabel.steps import LoadDataFromDicts
+
+ with Pipeline() as pipeline:
+ ds_size = 1000
+ batch_size = 500 # Bigger batch sizes work better for this step
+ data = LoadDataFromDicts(
+ data=[
+ {"text": "This is a test document."},
+ {"text": "This document is a test."},
+ {"text": "Test document for duplication."},
+ {"text": "Document for duplication test."},
+ {"text": "This is another unique document."},
+ ]
+ * (ds_size // 5),
+ batch_size=batch_size,
+ )
+ minhash_dedup = MinHashDedup(
+ tokenizer="words",
+ threshold=0.9, # lower values will increase the number of duplicates
+ storage="dict", # or "disk" for bigger datasets
+ )
+
+ data >> minhash_dedup
+
+ if __name__ == "__main__":
+ distiset = pipeline.run(use_cache=False)
+ ds = distiset["default"]["train"]
+ # Filter out the duplicates
+ ds_dedup = ds.filter(lambda x: x["keep_row_after_minhash_filtering"])
+ ```
+ """
+
+ num_perm: int = 128
+ seed: int = 1
+ tokenizer: Literal["words", "ngrams"] = "words"
+ n: Optional[int] = 5
+ threshold: float = 0.9
+ storage: Literal["dict", "disk"] = "dict"
+
+ _hasher: Union["MinHash", None] = PrivateAttr(None)
+ _tokenizer: Union[Callable, None] = PrivateAttr(None)
+ _lhs: Union["MinHashLSH", None] = PrivateAttr(None)
+
+ def load(self) -> None:
+ super().load()
+ if not importlib.import_module("datasketch"):
+ raise ImportError(
+ "`datasketch` is needed to deduplicate with MinHash, but is not installed. "
+ "Please install it using `pip install datasketch`."
+ )
+ from datasketch import MinHash
+
+ from distilabel.steps.filtering._datasketch import MinHashLSH
+
+ self._hasher = MinHash.bulk
+ self._lsh = MinHashLSH(
+ num_perm=self.num_perm,
+ threshold=self.threshold,
+ storage_config={"type": self.storage},
+ )
+
+ if self.tokenizer == "words":
+ if not importlib.import_module("nltk"):
+ raise ImportError(
+ "`nltk` is needed to tokenize based on words, but is not installed. "
+ "Please install it using `pip install nltk`. Then run `nltk.download('punkt_tab')`."
+ )
+ self._tokenizer = tokenized_on_words
+ else:
+ self._tokenizer = partial(tokenize_on_ngrams, n=self.n)
+
+ def unload(self) -> None:
+ super().unload()
+ # In case of LSH being stored in disk, we need to close the file.
+ if self.storage == "disk":
+ self._lsh.close()
+
+ @property
+ def inputs(self) -> List[str]:
+ return ["text"]
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["keep_row_after_minhash_filtering"]
+
+ def process(self, inputs: StepInput) -> "StepOutput":
+ tokenized_texts = []
+ for input in inputs:
+ tokenized_texts.append(self._tokenizer([input[self.inputs[0]]])[0])
+
+ minhashes = self._hasher(
+ tokenized_texts, num_perm=self.num_perm, seed=self.seed
+ )
+
+ for input, minhash in zip(inputs, minhashes):
+ # Check if the text is already in the LSH index
+ if self._lsh.query(minhash):
+ input["keep_row_after_minhash_filtering"] = False
+ else:
+ self._lsh.insert(str(uuid.uuid4()), minhash)
+ input["keep_row_after_minhash_filtering"] = True
+
+ yield inputs
diff --git a/src/distilabel/steps/formatting/conversation.py b/src/distilabel/steps/formatting/conversation.py
index 22cc582c54..29381521bd 100644
--- a/src/distilabel/steps/formatting/conversation.py
+++ b/src/distilabel/steps/formatting/conversation.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
class ConversationTemplate(Step):
@@ -36,7 +36,6 @@ class ConversationTemplate(Step):
- template
Examples:
-
Create a conversation from an instruction and a response:
```python
@@ -61,12 +60,12 @@ class ConversationTemplate(Step):
"""
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
"""The instruction and response."""
return ["instruction", "response"]
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""The conversation template."""
return ["conversation"]
diff --git a/src/distilabel/steps/formatting/dpo.py b/src/distilabel/steps/formatting/dpo.py
index 9402436ee9..72253eb194 100644
--- a/src/distilabel/steps/formatting/dpo.py
+++ b/src/distilabel/steps/formatting/dpo.py
@@ -18,7 +18,7 @@
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
class FormatTextGenerationDPO(Step):
@@ -65,7 +65,6 @@ class FormatTextGenerationDPO(Step):
- generations
Examples:
-
Format your dataset for DPO fine tuning:
```python
@@ -103,10 +102,16 @@ class FormatTextGenerationDPO(Step):
"""
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
"""List of inputs required by the `Step`, which in this case are: `instruction`, `generations`,
and `ratings`."""
- return ["instruction", "generations", "ratings"]
+ return {
+ "system_prompt": False,
+ "instruction": True,
+ "generations": True,
+ "generation_models": False,
+ "ratings": True,
+ }
@property
def optional_inputs(self) -> List[str]:
@@ -115,7 +120,7 @@ def optional_inputs(self) -> List[str]:
return ["system_prompt", "generation_models"]
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `chosen`,
`chosen_model`, `chosen_rating`, `rejected`, `rejected_model`, `rejected_rating`. Both
the `chosen_model` and `rejected_model` being optional and only used if `generation_models`
@@ -191,12 +196,11 @@ def process(self, *inputs: StepInput) -> "StepOutput": # type: ignore
class FormatChatGenerationDPO(Step):
- """Format the output of a combination of a `ChatGeneration` + a preference task such as
- `UltraFeedback`, for Direct Preference Optimization (DPO) following the standard formatting
- from frameworks such as `axolotl` or `alignment-handbook`.
+ """Format the output of a combination of a `ChatGeneration` + a preference task for Direct Preference Optimization (DPO).
`FormatChatGenerationDPO` is a `Step` that formats the output of the combination of a `ChatGeneration`
- task with a preference `Task` i.e. a task generating `ratings`, so that those are used to rank the
+ task with a preference `Task` i.e. a task generating `ratings` such as `UltraFeedback` following the standard
+ formatting from frameworks such as `axolotl` or `alignment-handbook`., so that those are used to rank the
existing generations and provide the `chosen` and `rejected` generations based on the `ratings`.
Note:
@@ -233,7 +237,6 @@ class FormatChatGenerationDPO(Step):
- generations
Examples:
-
Format your dataset for DPO fine tuning:
```python
@@ -272,7 +275,7 @@ class FormatChatGenerationDPO(Step):
"""
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
"""List of inputs required by the `Step`, which in this case are: `messages`, `generations`,
and `ratings`."""
return ["messages", "generations", "ratings"]
@@ -284,7 +287,7 @@ def optional_inputs(self) -> List[str]:
return ["generation_models"]
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `chosen`,
`chosen_model`, `chosen_rating`, `rejected`, `rejected_model`, `rejected_rating`. Both
the `chosen_model` and `rejected_model` being optional and only used if `generation_models`
diff --git a/src/distilabel/steps/formatting/sft.py b/src/distilabel/steps/formatting/sft.py
index ec93aadf79..2793b212e6 100644
--- a/src/distilabel/steps/formatting/sft.py
+++ b/src/distilabel/steps/formatting/sft.py
@@ -18,7 +18,7 @@
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
class FormatTextGenerationSFT(Step):
@@ -50,7 +50,6 @@ class FormatTextGenerationSFT(Step):
- generation
Examples:
-
Format your dataset for SFT fine tuning:
```python
@@ -84,9 +83,13 @@ class FormatTextGenerationSFT(Step):
"""
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
"""List of inputs required by the `Step`, which in this case are: `instruction`, and `generation`."""
- return ["instruction", "generation"]
+ return {
+ "system_prompt": False,
+ "instruction": True,
+ "generation": True,
+ }
@property
def optional_inputs(self) -> List[str]:
@@ -95,7 +98,7 @@ def optional_inputs(self) -> List[str]:
return ["system_prompt"]
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `messages`.
Reference:
@@ -139,8 +142,7 @@ def process(self, *inputs: StepInput) -> "StepOutput": # type: ignore
class FormatChatGenerationSFT(Step):
- """Format the output of a `ChatGeneration` task for Supervised Fine-Tuning (SFT) following the
- standard formatting from frameworks such as `axolotl` or `alignment-handbook`.
+ """Format the output of a `ChatGeneration` task for Supervised Fine-Tuning (SFT).
`FormatChatGenerationSFT` is a `Step` that formats the output of a `ChatGeneration` task for
Supervised Fine-Tuning (SFT) following the standard formatting from frameworks such as `axolotl`
@@ -168,8 +170,7 @@ class FormatChatGenerationSFT(Step):
- generation
Examples:
-
- Format your dataset for Supervised Fine Tuning (SFT):
+ Format your dataset for SFT:
```python
from distilabel.steps import FormatChatGenerationSFT
@@ -201,12 +202,12 @@ class FormatChatGenerationSFT(Step):
"""
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
"""List of inputs required by the `Step`, which in this case are: `instruction`, and `generation`."""
return ["messages", "generation"]
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `messages`.
Reference:
diff --git a/src/distilabel/steps/generators/data.py b/src/distilabel/steps/generators/data.py
index fbf29ec7fe..803ee35eac 100644
--- a/src/distilabel/steps/generators/data.py
+++ b/src/distilabel/steps/generators/data.py
@@ -14,6 +14,7 @@
from typing import TYPE_CHECKING, Any, Dict, List
+from pydantic import Field
from typing_extensions import override
from distilabel.steps.base import GeneratorStep
@@ -42,7 +43,6 @@ class LoadDataFromDicts(GeneratorStep):
- load
Examples:
-
Load data from a list of dictionaries:
```python
@@ -60,7 +60,7 @@ class LoadDataFromDicts(GeneratorStep):
```
"""
- data: List[Dict[str, Any]]
+ data: List[Dict[str, Any]] = Field(default_factory=list, exclude=True)
@override
def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore
diff --git a/src/distilabel/steps/generators/data_sampler.py b/src/distilabel/steps/generators/data_sampler.py
new file mode 100644
index 0000000000..6b2e55bf02
--- /dev/null
+++ b/src/distilabel/steps/generators/data_sampler.py
@@ -0,0 +1,179 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from itertools import islice
+from typing import TYPE_CHECKING, Any, Dict, List
+
+from pydantic import Field
+from typing_extensions import override
+
+from distilabel.steps.base import GeneratorStep
+
+if TYPE_CHECKING:
+ from distilabel.steps.base import GeneratorStepOutput
+
+
+class DataSampler(GeneratorStep):
+ """Step to sample from a dataset.
+
+ `GeneratorStep` that samples from a dataset and yields it in batches.
+ This step is useful when you have a pipeline that can benefit from using examples
+ in the prompts for example as few-shot learning, that can be changing on each row.
+ For example, you can pass a list of dictionaries with N examples and generate M samples
+ from it (assuming you have another step loading data, this M should have the same size
+ as the data being loaded in that step). The size S argument is the number of samples per
+ row generated, so each example would contain S examples to be used as examples.
+
+ Attributes:
+ data: The list of dictionaries to sample from.
+ size: Number of samples per example. For example in a few-shot learning scenario,
+ the number of few-shot examples that will be generated per example. Defaults to 2.
+ samples: Number of examples that will be generated by the step in total.
+ If used with another loader step, this should be the same as the number
+ of samples in the loader step. Defaults to 100.
+
+ Output columns:
+ - dynamic (based on the keys found on the first dictionary of the list): The columns
+ of the dataset.
+
+ Categories:
+ - load
+
+ Examples:
+ Sample data from a list of dictionaries:
+
+ ```python
+ from distilabel.steps import DataSampler
+
+ sampler = DataSampler(
+ data=[{"sample": f"sample {i}"} for i in range(30)],
+ samples=10,
+ size=2,
+ batch_size=4
+ )
+ sampler.load()
+
+ result = next(sampler.process())
+ # >>> result
+ # ([{'sample': ['sample 7', 'sample 0']}, {'sample': ['sample 2', 'sample 21']}, {'sample': ['sample 17', 'sample 12']}, {'sample': ['sample 2', 'sample 14']}], False)
+ ```
+
+ Pipeline with a loader and a sampler combined in a single stream:
+
+ ```python
+ from datasets import load_dataset
+
+ from distilabel.steps import LoadDataFromDicts, DataSampler
+ from distilabel.steps.tasks.apigen.utils import PrepareExamples
+ from distilabel.pipeline import Pipeline
+
+ ds = (
+ load_dataset("Salesforce/xlam-function-calling-60k", split="train")
+ .shuffle(seed=42)
+ .select(range(500))
+ .to_list()
+ )
+ data = [
+ {
+ "func_name": "final_velocity",
+ "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+ },
+ {
+ "func_name": "permutation_count",
+ "func_desc": "Calculates the number of permutations of k elements from a set of n elements.",
+ },
+ {
+ "func_name": "getdivision",
+ "func_desc": "Divides two numbers by making an API call to a division service.",
+ },
+ ]
+ with Pipeline(name="APIGenPipeline") as pipeline:
+ loader_seeds = LoadDataFromDicts(data=data)
+ sampler = DataSampler(
+ data=ds,
+ size=2,
+ samples=len(data),
+ batch_size=8,
+ )
+ prep_examples = PrepareExamples()
+
+ sampler >> prep_examples
+ (
+ [loader_seeds, prep_examples]
+ >> combine_steps
+ )
+ # Now we have a single stream of data with the loader and the sampler data
+ ```
+ """
+
+ data: List[Dict[str, Any]] = Field(default_factory=list, exclude=True)
+ size: int = Field(
+ default=2,
+ description=(
+ "Number of samples per example. For example in a few-shot learning scenario, the number "
+ "of few-shot examples that will be generated per example."
+ ),
+ )
+ samples: int = Field(
+ default=100,
+ description=(
+ "Number of examples that will be generated by the step in total. "
+ "If used with another loader step, this should be the same as the number of "
+ "samples in the loader step."
+ ),
+ )
+
+ @override
+ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore
+ """Yields batches from a list of dictionaries.
+
+ Args:
+ offset: The offset to start the generation from. Defaults to `0`.
+
+ Yields:
+ A list of Python dictionaries as read from the inputs (propagated in batches)
+ and a flag indicating whether the yield batch is the last one.
+ """
+
+ total_samples = 0
+
+ while total_samples < self.samples:
+ batch = []
+ bs = min(self.batch_size, self.samples - total_samples)
+ for _ in range(self.batch_size):
+ choices = random.choices(self.data, k=self.size)
+ choices = self._transform_data(choices)
+ batch.extend(choices)
+ total_samples += bs
+ batch = list(islice(batch, bs))
+ yield (batch, True if total_samples >= self.samples else False)
+ batch = []
+
+ @staticmethod
+ def _transform_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ if not data:
+ return []
+
+ result = {key: [] for key in data[0].keys()}
+
+ for item in data:
+ for key, value in item.items():
+ result[key].append(value)
+
+ return [result]
+
+ @property
+ def outputs(self) -> List[str]:
+ return list(self.data[0].keys())
diff --git a/src/distilabel/steps/generators/huggingface.py b/src/distilabel/steps/generators/huggingface.py
index b31b9fbadc..721b3d4081 100644
--- a/src/distilabel/steps/generators/huggingface.py
+++ b/src/distilabel/steps/generators/huggingface.py
@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import warnings
from collections import defaultdict
from functools import cached_property
from pathlib import Path
from typing import (
TYPE_CHECKING,
+ Annotated,
Any,
Dict,
List,
@@ -24,6 +26,7 @@
Optional,
Sequence,
Tuple,
+ TypeVar,
Union,
)
@@ -39,6 +42,7 @@
from upath import UPath
from distilabel.distiset import Distiset
+from distilabel.errors import DistilabelUserError
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import GeneratorStep
@@ -46,6 +50,13 @@
from distilabel.steps.typing import GeneratorStepOutput
+T = TypeVar("T")
+
+# To avoid using repo_id in LoadDataFromFileSystem:
+# https://github.com/pydantic/pydantic/discussions/7076#discussioncomment-6699138
+ExcludedField = Annotated[T, Field(exclude=True)]
+
+
class LoadDataFromHub(GeneratorStep):
"""Loads a dataset from the Hugging Face Hub.
@@ -64,6 +75,7 @@ class LoadDataFromHub(GeneratorStep):
- `split`: The split of the dataset to load. Defaults to 'train'.
- `config`: The configuration of the dataset to load. This is optional and only
needed if the dataset has multiple configurations.
+ - `revision`: The revision of the dataset to load. Defaults to the latest revision.
- `streaming`: Whether to load the dataset in streaming mode or not. Defaults to
`False`.
- `num_examples`: The number of examples to load from the dataset.
@@ -79,7 +91,6 @@ class LoadDataFromHub(GeneratorStep):
- load
Examples:
-
Load data from a dataset in Hugging Face Hub:
```python
@@ -112,6 +123,10 @@ class LoadDataFromHub(GeneratorStep):
description="The configuration of the dataset to load. This is optional and only"
" needed if the dataset has multiple configurations.",
)
+ revision: Optional[RuntimeParameter[str]] = Field(
+ default=None,
+ description="The revision of the dataset to load. Defaults to the latest revision.",
+ )
streaming: RuntimeParameter[bool] = Field(
default=False,
description="Whether to load the dataset in streaming mode or not. Defaults to False.",
@@ -139,6 +154,7 @@ def load(self) -> None:
self.repo_id, # type: ignore
self.config,
split=self.split,
+ revision=self.revision,
streaming=self.streaming,
)
num_examples = self._get_dataset_num_examples()
@@ -203,11 +219,11 @@ def _get_dataset_num_examples(self) -> int:
Returns:
The number of examples in the dataset.
"""
- return (
- self._dataset_info[self.config if self.config else "default"]
- .splits[self.split]
- .num_examples
- )
+ default_config = self.config
+ if not default_config:
+ default_config = list(self._dataset_info.keys())[0]
+
+ return self._dataset_info[default_config].splits[self.split].num_examples
def _get_dataset_columns(self) -> List[str]:
"""Get the columns of the dataset, based on the `config` runtime parameter provided.
@@ -228,20 +244,18 @@ def _dataset_info(self) -> Dict[str, DatasetInfo]:
Returns:
The dataset information.
"""
- repo_id = self.repo_id
- config = self.config
try:
- return get_dataset_infos(repo_id)
+ return get_dataset_infos(self.repo_id)
except Exception as e:
- # The previous could fail in case of a internet connection issues.
- # Assuming the dataset is already loaded and we can get the info from the loaded dataset, otherwise it will fail anyway.
- self._logger.warning(
- f"Failed to get dataset info from Hugging Face Hub, trying to get it loading the dataset. Error: {e}"
+ warnings.warn(
+ f"Failed to get dataset info from Hugging Face Hub, trying to get it loading the dataset. Error: {e}",
+ UserWarning,
+ stacklevel=2,
)
- ds = load_dataset(repo_id, config=self.config, split=self.split)
- if config:
- return ds[config].info
+ ds = load_dataset(self.repo_id, config=self.config, split=self.split)
+ if self.config:
+ return ds[self.config].info
return ds.info
@@ -279,7 +293,6 @@ class LoadDataFromFileSystem(LoadDataFromHub):
- load
Examples:
-
Load data from a Hugging Face dataset in your file system:
```python
@@ -324,6 +337,23 @@ class LoadDataFromFileSystem(LoadDataFromHub):
# >>> result
# ([{'type': 'function', 'function':...', False)
```
+
+ Load data passing a glob pattern:
+
+ ```python
+ from distilabel.steps import LoadDataFromFileSystem
+
+ loader = LoadDataFromFileSystem(
+ data_files="path/to/dataset/*.jsonl",
+ streaming=True
+ )
+ loader.load()
+
+ # Just like we saw with LoadDataFromDicts, the `process` method will yield batches.
+ result = next(loader.process())
+ # >>> result
+ # ([{'type': 'function', 'function':...', False)
+ ```
"""
data_files: RuntimeParameter[Union[str, Path]] = Field(
@@ -334,6 +364,7 @@ class LoadDataFromFileSystem(LoadDataFromHub):
default=None,
description="The expected filetype. If not provided, it will be inferred from the file extension.",
)
+ repo_id: ExcludedField[Union[str, None]] = None
def load(self) -> None:
"""Load the dataset from the file/s in disk."""
@@ -366,7 +397,7 @@ def load(self) -> None:
self.num_examples = len(self._dataset)
@staticmethod
- def _prepare_data_files(
+ def _prepare_data_files( # noqa: C901
data_path: UPath,
) -> Tuple[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]], str]:
"""Prepare the loading process by setting the `data_files` attribute.
@@ -384,9 +415,12 @@ def get_filetype(data_path: UPath) -> str:
filetype = "json"
return filetype
- if data_path.is_file():
+ if data_path.is_file() or (
+ len(str(data_path.parent.glob(data_path.name))) >= 1
+ ):
filetype = get_filetype(data_path)
data_files = str(data_path)
+
elif data_path.is_dir():
file_sequence = []
file_map = defaultdict(list)
@@ -416,9 +450,7 @@ def outputs(self) -> List[str]:
"""
# We assume there are Dataset/IterableDataset, not it's ...Dict counterparts
if self._dataset is None:
- raise ValueError(
- "Dataset not loaded yet, you must call `load` method first."
- )
+ self.load()
return self._dataset.column_names
@@ -432,16 +464,16 @@ class LoadDataFromDisk(LoadDataFromHub):
Attributes:
dataset_path: The path to the dataset or distiset.
split: The split of the dataset to load (typically will be `train`, `test` or `validation`).
- config: The configuration of the dataset to load. This is optional and only needed
- if the dataset has multiple configurations.
+ config: The configuration of the dataset to load. Defaults to `default`, if there are
+ multiple configurations in the dataset this must be suplied or an error is raised.
Runtime parameters:
- `batch_size`: The batch size to use when processing the data.
- `dataset_path`: The path to the dataset or distiset.
- `is_distiset`: Whether the dataset to load is a `Distiset` or not. Defaults to False.
- `split`: The split of the dataset to load. Defaults to 'train'.
- - `config`: The configuration of the dataset to load. This is optional and only
- needed if the dataset has multiple configurations.
+ - `config`: The configuration of the dataset to load. Defaults to `default`, if there are
+ multiple configurations in the dataset this must be suplied or an error is raised.
- `num_examples`: The number of examples to load from the dataset.
By default will load all examples.
- `storage_options`: Key/value pairs to be passed on to the file-system backend, if any.
@@ -455,7 +487,6 @@ class LoadDataFromDisk(LoadDataFromHub):
- load
Examples:
-
Load data from a Hugging Face Dataset:
```python
@@ -511,10 +542,12 @@ class LoadDataFromDisk(LoadDataFromHub):
default=None,
description="Path to the dataset or distiset.",
)
- config: RuntimeParameter[str] = Field(
- default=None,
- description="The configuration of the dataset to load. This is optional and only"
- " needed if the dataset has multiple configurations.",
+ config: Optional[RuntimeParameter[str]] = Field(
+ default="default",
+ description=(
+ "The configuration of the dataset to load. Will default to 'default'",
+ " which corresponds to a distiset with a single configuration.",
+ ),
)
is_distiset: Optional[RuntimeParameter[bool]] = Field(
default=False,
@@ -529,6 +562,7 @@ class LoadDataFromDisk(LoadDataFromHub):
default=None,
description="The split of the dataset to load. By default will load the whole Dataset/Distiset.",
)
+ repo_id: ExcludedField[Union[str, None]] = None
def load(self) -> None:
"""Load the dataset from the file/s in disk."""
@@ -539,8 +573,14 @@ def load(self) -> None:
keep_in_memory=self.keep_in_memory,
storage_options=self.storage_options,
)
- if self.config:
- ds = ds[self.config]
+ if self.config not in ds.keys():
+ raise DistilabelUserError(
+ f"Configuration '{self.config}' not found in the Distiset, available ones"
+ f" are: {list(ds.keys())}. Please try changing the `config` parameter to one "
+ "of the available configurations.\n\n",
+ page="sections/how_to_guides/advanced/distiset/#using-the-distiset-dataset-object",
+ )
+ ds = ds[self.config]
else:
ds = load_from_disk(
@@ -568,9 +608,7 @@ def outputs(self) -> List[str]:
The columns that will be generated by this step.
"""
# We assume there are Dataset/IterableDataset, not it's ...Dict counterparts
- if self._dataset is Ellipsis:
- raise ValueError(
- "Dataset not loaded yet, you must call `load` method first."
- )
+ if self._dataset is None:
+ self.load()
return self._dataset.column_names
diff --git a/src/distilabel/steps/generators/utils.py b/src/distilabel/steps/generators/utils.py
index b9e111d9b9..49d27748b4 100644
--- a/src/distilabel/steps/generators/utils.py
+++ b/src/distilabel/steps/generators/utils.py
@@ -17,18 +17,22 @@
import pandas as pd
from datasets import Dataset
+from distilabel.errors import DistilabelUserError
from distilabel.steps.base import StepResources
if TYPE_CHECKING:
+ from distilabel.pipeline.base import BasePipeline
from distilabel.steps import GeneratorStep
def make_generator_step(
dataset: Union[Dataset, pd.DataFrame, List[Dict[str, str]]],
+ pipeline: Union["BasePipeline", None] = None,
batch_size: int = 50,
input_mappings: Optional[Dict[str, str]] = None,
output_mappings: Optional[Dict[str, str]] = None,
resources: StepResources = StepResources(),
+ repo_id: Optional[str] = "default_name",
) -> "GeneratorStep":
"""Helper method to create a `GeneratorStep` from a dataset, to simplify
@@ -39,6 +43,9 @@ def make_generator_step(
input_mappings: Applies the same as any other step. Defaults to `None`.
output_mappings: Applies the same as any other step. Defaults to `None`.
resources: Applies the same as any other step. Defaults to `StepResources()`.
+ repo_id: The repository ID to use in the `LoadDataFromHub` step.
+ This shouldn't be necessary, but in case of error, the dataset will try to be loaded
+ using `load_dataset` internally. If that case happens, the `repo_id` will be used.
Raises:
ValueError: If the format is different from the ones supported.
@@ -51,6 +58,7 @@ def make_generator_step(
if isinstance(dataset, list):
return LoadDataFromDicts(
+ pipeline=pipeline,
data=dataset,
batch_size=batch_size,
input_mappings=input_mappings or {},
@@ -62,18 +70,21 @@ def make_generator_step(
dataset = Dataset.from_pandas(dataset, preserve_index=False)
if not isinstance(dataset, Dataset):
- raise ValueError(
+ raise DistilabelUserError(
f"Dataset type not allowed: {type(dataset)}, must be one of: "
- "`datasets.Dataset`, `pd.DataFrame`, `List[Dict[str, str]]`"
+ "`datasets.Dataset`, `pd.DataFrame`, `List[Dict[str, str]]`",
+ page="sections/how_to_guides/basic/pipeline/?h=make_#__tabbed_1_2",
)
loader = LoadDataFromHub(
- repo_id="placeholder_name",
+ pipeline=pipeline,
+ repo_id=repo_id,
batch_size=batch_size,
input_mappings=input_mappings or {},
output_mappings=output_mappings or {},
resources=resources,
)
+ super(loader.__class__, loader).load() # Ensure the logger is loaded
loader._dataset = dataset
loader.num_examples = len(dataset)
loader._dataset_info = {"default": dataset.info}
diff --git a/src/distilabel/steps/globals/huggingface.py b/src/distilabel/steps/globals/huggingface.py
index 28ef3932bd..82e7f35ab6 100644
--- a/src/distilabel/steps/globals/huggingface.py
+++ b/src/distilabel/steps/globals/huggingface.py
@@ -58,7 +58,6 @@ class PushToHub(GlobalStep):
- huggingface
Examples:
-
Push batches of your dataset to the Hugging Face Hub repository:
```python
diff --git a/src/distilabel/steps/reward_model.py b/src/distilabel/steps/reward_model.py
index 1617d86bc3..49ddc065df 100644
--- a/src/distilabel/steps/reward_model.py
+++ b/src/distilabel/steps/reward_model.py
@@ -26,7 +26,7 @@
from transformers import PreTrainedModel, PreTrainedTokenizer
from distilabel.steps.tasks.typing import ChatType
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
class RewardModelScore(Step, CudaDevicePlacementMixin):
@@ -68,7 +68,6 @@ class RewardModelScore(Step, CudaDevicePlacementMixin):
- scorer
Examples:
-
Assigning an score for an instruction-response pair:
```python
@@ -179,12 +178,16 @@ def load(self) -> None:
)
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
"""Either `response` and `instruction`, or a `conversation` columns."""
- return []
+ return {
+ "response": False,
+ "instruction": False,
+ "conversation": False,
+ }
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""The `score` given by the reward model."""
return ["score"]
diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py
index 0b3a69596b..725fd065fd 100644
--- a/src/distilabel/steps/tasks/__init__.py
+++ b/src/distilabel/steps/tasks/__init__.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from distilabel.steps.tasks.apigen.execution_checker import APIGenExecutionChecker
+from distilabel.steps.tasks.apigen.generator import APIGenGenerator
+from distilabel.steps.tasks.apigen.semantic_checker import APIGenSemanticChecker
+from distilabel.steps.tasks.argilla_labeller import ArgillaLabeller
from distilabel.steps.tasks.base import GeneratorTask, Task
+from distilabel.steps.tasks.clair import CLAIR
from distilabel.steps.tasks.complexity_scorer import ComplexityScorer
from distilabel.steps.tasks.evol_instruct.base import EvolInstruct
from distilabel.steps.tasks.evol_instruct.evol_complexity.base import EvolComplexity
@@ -43,13 +48,19 @@
from distilabel.steps.tasks.self_instruct import SelfInstruct
from distilabel.steps.tasks.sentence_transformers import GenerateSentencePair
from distilabel.steps.tasks.structured_generation import StructuredGeneration
+from distilabel.steps.tasks.text_classification import TextClassification
from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
from distilabel.steps.tasks.typing import ChatItem, ChatType
from distilabel.steps.tasks.ultrafeedback import UltraFeedback
+from distilabel.steps.tasks.urial import URIAL
__all__ = [
"GeneratorTask",
"Task",
+ "ArgillaLabeller",
+ "APIGenExecutionChecker",
+ "APIGenGenerator",
+ "APIGenSemanticChecker",
"ComplexityScorer",
"EvolInstruct",
"EvolComplexity",
@@ -74,9 +85,12 @@
"SelfInstruct",
"GenerateSentencePair",
"StructuredGeneration",
+ "TextClassification",
"ChatGeneration",
"TextGeneration",
"ChatItem",
"ChatType",
+ "CLAIR",
"UltraFeedback",
+ "URIAL",
]
diff --git a/src/distilabel/steps/tasks/apigen/__init__.py b/src/distilabel/steps/tasks/apigen/__init__.py
new file mode 100644
index 0000000000..20ce00bda7
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/distilabel/steps/tasks/apigen/execution_checker.py b/src/distilabel/steps/tasks/apigen/execution_checker.py
new file mode 100644
index 0000000000..7d30dd1f75
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/execution_checker.py
@@ -0,0 +1,268 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# - Try to import the function from a given module
+# - If function, try to import it and run it
+# - If fails, track the error message, and return it
+
+import inspect
+import json
+from pathlib import Path
+from typing import TYPE_CHECKING, Callable, Union
+
+from pydantic import Field, PrivateAttr
+from typing_extensions import override
+
+from distilabel.steps.base import Step, StepInput
+from distilabel.steps.tasks.apigen.utils import (
+ execute_from_response,
+ load_module_from_path,
+)
+
+if TYPE_CHECKING:
+ from types import ModuleType
+
+ from distilabel.steps.typing import StepColumns, StepOutput
+
+
+class APIGenExecutionChecker(Step):
+ """Executes the generated function calls.
+
+ This step checks if a given answer from a model as generated by `APIGenGenerator`
+ can be executed against the given library (given by `libpath`, which is a string
+ pointing to a python .py file with functions).
+
+ Attributes:
+ libpath: The path to the library where we will retrieve the functions.
+ It can also point to a folder with the functions. In this case, the folder
+ layout should be a folder with .py files, each containing a single function,
+ the name of the function being the same as the filename.
+ check_is_dangerous: Bool to exclude some potentially dangerous functions, it contains
+ some heuristics found while testing. This functions can run subprocesses, deal with
+ the OS, or have other potentially dangerous operations. Defaults to True.
+
+ Input columns:
+ - answers (`str`): List with arguments to be passed to the function,
+ dumped as a string from a list of dictionaries. Should be loaded using
+ `json.loads`.
+
+ Output columns:
+ - keep_row_after_execution_check (`bool`): Whether the function should be kept or not.
+ - execution_result (`str`): The result from executing the function.
+
+ Categories:
+ - filtering
+ - execution
+
+ References:
+ - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
+ - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
+
+ Examples:
+ Execute a function from a given library with the answer from an LLM:
+
+ ```python
+ from distilabel.steps.tasks import APIGenExecutionChecker
+
+ # For the libpath you can use as an example the file at the tests folder:
+ # ../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py
+ task = APIGenExecutionChecker(
+ libpath="../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py",
+ )
+ task.load()
+
+ res = next(
+ task.process(
+ [
+ {
+ "answers": [
+ {
+ "arguments": {
+ "initial_velocity": 0.2,
+ "acceleration": 0.1,
+ "time": 0.5,
+ },
+ "name": "final_velocity",
+ }
+ ],
+ }
+ ]
+ )
+ )
+ res
+ #[{'answers': [{'arguments': {'initial_velocity': 0.2, 'acceleration': 0.1, 'time': 0.5}, 'name': 'final_velocity'}], 'keep_row_after_execution_check': True, 'execution_result': ['0.25']}]
+ ```
+ """
+
+ libpath: str = Field(
+ default=...,
+ description=(
+ "The path to the library where we will retrieve the functions, "
+ "or a folder with python files named the same as the functions they contain.",
+ ),
+ )
+ check_is_dangerous: bool = Field(
+ default=True,
+ description=(
+ "Bool to exclude some potentially dangerous functions, it contains "
+ "some heuristics found while testing. This functions can run subprocesses, "
+ "deal with the OS, or have other potentially dangerous operations.",
+ ),
+ )
+
+ _toolbox: Union["ModuleType", None] = PrivateAttr(None)
+
+ def load(self) -> None:
+ """Loads the library where the functions will be extracted from."""
+ super().load()
+ if Path(self.libpath).suffix == ".py":
+ self._toolbox = load_module_from_path(self.libpath)
+
+ def unload(self) -> None:
+ self._toolbox = None
+
+ @property
+ def inputs(self) -> "StepColumns":
+ """The inputs for the task are those found in the original dataset."""
+ return ["answers"]
+
+ @property
+ def outputs(self) -> "StepColumns":
+ """The outputs are the columns required by `APIGenGenerator` task."""
+ return ["keep_row_after_execution_check", "execution_result"]
+
+ def _get_function(self, function_name: str) -> Callable:
+ """Retrieves the function from the toolbox.
+
+ Args:
+ function_name: The name of the function to retrieve.
+
+ Returns:
+ Callable: The function to be executed.
+ """
+ if self._toolbox:
+ return getattr(self._toolbox, function_name, None)
+ try:
+ toolbox = load_module_from_path(
+ str(Path(self.libpath) / f"{function_name}.py")
+ )
+ return getattr(toolbox, function_name, None)
+ except FileNotFoundError:
+ return None
+ except Exception as e:
+ self._logger.warning(f"Error loading function '{function_name}': {e}")
+ return None
+
+ def _is_dangerous(self, function: Callable) -> bool:
+ """Checks if a function is dangerous to remove it.
+ Contains a list of heuristics to avoid executing possibly dangerous functions.
+ """
+ source_code = inspect.getsource(function)
+ # We don't want to execute functions that use subprocess
+ if (
+ ("subprocess." in source_code)
+ or ("os.system(" in source_code)
+ or ("input(" in source_code)
+ # Avoiding threading
+ or ("threading.Thread(" in source_code)
+ or ("exec(" in source_code)
+ # Avoiding argparse (not sure why)
+ or ("argparse.ArgumentParser(" in source_code)
+ # Avoiding logging changing the levels to not mess with the logs
+ or (".setLevel(" in source_code)
+ # Don't run a test battery
+ or ("unittest.main(" in source_code)
+ # Avoid exiting the program
+ or ("sys.exit(" in source_code)
+ or ("exit(" in source_code)
+ or ("raise SystemExit(" in source_code)
+ or ("multiprocessing.Pool(" in source_code)
+ ):
+ return True
+ return False
+
+ @override
+ def process(self, inputs: StepInput) -> "StepOutput":
+ """Checks the answer to see if it can be executed.
+ Captures the possible errors and returns them.
+
+ If a single example is provided, it is copied to avoid raising an error.
+
+ Args:
+ inputs: A list of dictionaries with the input data.
+
+ Yields:
+ A list of dictionaries with the output data.
+ """
+ for input in inputs:
+ output = []
+ if input["answers"]:
+ answers = json.loads(input["answers"])
+ else:
+ input.update(
+ **{
+ "keep_row_after_execution_check": False,
+ "execution_result": ["No answers were provided."],
+ }
+ )
+ continue
+ for answer in answers:
+ if answer is None:
+ output.append(
+ {
+ "keep": False,
+ "execution_result": "Nothing was generated for this answer.",
+ }
+ )
+ continue
+
+ function_name = answer.get("name", None)
+ arguments = answer.get("arguments", None)
+
+ self._logger.debug(
+ f"Executing function '{function_name}' with arguments: {arguments}"
+ )
+ function = self._get_function(function_name)
+
+ if self.check_is_dangerous:
+ if function and self._is_dangerous(function):
+ function = None
+
+ if function is None:
+ output.append(
+ {
+ "keep": False,
+ "execution_result": f"Function '{function_name}' not found.",
+ }
+ )
+ else:
+ execution = execute_from_response(function, arguments)
+ output.append(
+ {
+ "keep": execution["keep"],
+ "execution_result": execution["execution_result"],
+ }
+ )
+ # We only consider a good response if all the answers were executed successfully,
+ # but keep the reasons for further review if needed.
+ input.update(
+ **{
+ "keep_row_after_execution_check": all(
+ o["keep"] is True for o in output
+ ),
+ "execution_result": [o["execution_result"] for o in output],
+ }
+ )
+
+ yield inputs
diff --git a/src/distilabel/steps/tasks/apigen/generator.py b/src/distilabel/steps/tasks/apigen/generator.py
new file mode 100644
index 0000000000..c1c691e378
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/generator.py
@@ -0,0 +1,448 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.resources as importlib_resources
+import json
+import random
+from typing import TYPE_CHECKING, Any, Callable, Dict, Final, List, Union
+
+import orjson
+from jinja2 import Template
+from pydantic import PrivateAttr
+from typing_extensions import override
+
+from distilabel.steps.tasks.apigen.utils import remove_fences
+from distilabel.steps.tasks.base import Task
+
+if TYPE_CHECKING:
+ from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.typing import StepColumns
+
+
+SYSTEM_PROMPT_API_GEN: Final[str] = """\
+You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.
+
+Construct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.
+
+Ensure the query:
+- Is clear and concise
+- Demonstrates typical use cases
+- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words
+- Across a variety level of difficulties, ranging from beginner and advanced use cases
+- The corresponding result's parameter types and ranges match with the function's descriptions
+
+Ensure the answer:
+- Is a list of function calls in JSON format
+- The length of the answer list should be equal to the number of requests in the query
+- Can solve all the requests in the query effectively"""
+
+
+class APIGenGenerator(Task):
+ """Generate queries and answers for the given functions in JSON format.
+
+ The `APIGenGenerator` is inspired by the APIGen pipeline, which was designed to generate
+ verifiable and diverse function-calling datasets. The task generates a set of diverse queries
+ and corresponding answers for the given functions in JSON format.
+
+ Attributes:
+ system_prompt: The system prompt to guide the user in the generation of queries and answers.
+ use_tools: Whether to use the tools available in the prompt to generate the queries and answers.
+ In case the tools are given in the input, they will be added to the prompt.
+ number: The number of queries to generate. It can be a list, where each number will be
+ chosen randomly, or a dictionary with the number of queries and the probability of each.
+ I.e: `number=1`, `number=[1, 2, 3]`, `number={1: 0.5, 2: 0.3, 3: 0.2}` are all valid inputs.
+ It corresponds to the number of parallel queries to generate.
+ use_default_structured_output: Whether to use the default structured output or not.
+
+ Input columns:
+ - examples (`str`): Examples used as few shots to guide the model.
+ - func_name (`str`): Name for the function to generate.
+ - func_desc (`str`): Description of what the function should do.
+ - tools (`str`): JSON formatted string containing the tool representation of the function.
+
+ Output columns:
+ - query (`str`): The list of queries.
+ - answers (`str`): JSON formatted string with the list of answers, containing the info as
+ a dictionary to be passed to the functions.
+
+ Categories:
+ - text-generation
+
+ References:
+ - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
+ - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
+
+ Examples:
+ Generate without structured output (original implementation):
+
+ ```python
+ from distilabel.steps.tasks import ApiGenGenerator
+ from distilabel.llms import InferenceEndpointsLLM
+
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 1024,
+ },
+ )
+ apigen = ApiGenGenerator(
+ use_default_structured_output=False,
+ llm=llm
+ )
+ apigen.load()
+
+ res = next(
+ apigen.process(
+ [
+ {
+ "examples": 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+ "func_name": "getrandommovie",
+ "func_desc": "Returns a list of random movies from a database by calling an external API."
+ }
+ ]
+ )
+ )
+ res
+ # [{'examples': 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+ # 'number': 1,
+ # 'func_name': 'getrandommovie',
+ # 'func_desc': 'Returns a list of random movies from a database by calling an external API.',
+ # 'queries': ['I want to watch a movie tonight, can you recommend a random one from your database?',
+ # 'Give me 5 random movie suggestions from your database to plan my weekend.'],
+ # 'answers': [[{'name': 'getrandommovie', 'arguments': {}}],
+ # [{'name': 'getrandommovie', 'arguments': {}},
+ # {'name': 'getrandommovie', 'arguments': {}},
+ # {'name': 'getrandommovie', 'arguments': {}},
+ # {'name': 'getrandommovie', 'arguments': {}},
+ # {'name': 'getrandommovie', 'arguments': {}}]],
+ # 'raw_input_api_gen_generator_0': [{'role': 'system',
+ # 'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively"},
+ # {'role': 'user',
+ # 'content': 'Here are examples of queries and the corresponding answers for similar functions:\nQUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\nBased on these examples, generate 2 diverse query and answer pairs for the function `getrandommovie`\nThe detailed function description is the following:\nReturns a list of random movies from a database by calling an external API.\n\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n```json\n[\n {\n "query": "The generated query.",\n "answers": [\n {\n "name": "api_name",\n "arguments": {\n "arg_name": "value"\n ... (more arguments as required)\n }\n },\n ... (more API calls as required)\n ]\n }\n]\n```\n\nNow please generate 2 diverse query and answer pairs following the above format.'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+
+ Generate with structured output:
+
+ ```python
+ from distilabel.steps.tasks import ApiGenGenerator
+ from distilabel.llms import InferenceEndpointsLLM
+
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ tokenizer="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 1024,
+ },
+ )
+ apigen = ApiGenGenerator(
+ use_default_structured_output=True,
+ llm=llm
+ )
+ apigen.load()
+
+ res_struct = next(
+ apigen.process(
+ [
+ {
+ "examples": 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+ "func_name": "getrandommovie",
+ "func_desc": "Returns a list of random movies from a database by calling an external API."
+ }
+ ]
+ )
+ )
+ res_struct
+ # [{'examples': 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+ # 'number': 1,
+ # 'func_name': 'getrandommovie',
+ # 'func_desc': 'Returns a list of random movies from a database by calling an external API.',
+ # 'queries': ["I'm bored and want to watch a movie. Can you suggest some movies?",
+ # "My family and I are planning a movie night. We can't decide on what to watch. Can you suggest some random movie titles?"],
+ # 'answers': [[{'arguments': {}, 'name': 'getrandommovie'}],
+ # [{'arguments': {}, 'name': 'getrandommovie'}]],
+ # 'raw_input_api_gen_generator_0': [{'role': 'system',
+ # 'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively"},
+ # {'role': 'user',
+ # 'content': 'Here are examples of queries and the corresponding answers for similar functions:\nQUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\nBased on these examples, generate 2 diverse query and answer pairs for the function `getrandommovie`\nThe detailed function description is the following:\nReturns a list of random movies from a database by calling an external API.\n\nNow please generate 2 diverse query and answer pairs following the above format.'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+ """
+
+ system_prompt: str = SYSTEM_PROMPT_API_GEN
+ use_default_structured_output: bool = False
+ number: Union[int, List[int], Dict[int, float]] = 1
+ use_tools: bool = True
+
+ _number: Union[int, None] = PrivateAttr(None)
+ _fn_parallel_queries: Union[Callable[[], str], None] = PrivateAttr(None)
+ _format_inst: Union[str, None] = PrivateAttr(None)
+
+ def load(self) -> None:
+ """Loads the template for the generator prompt."""
+ super().load()
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "apigen"
+ / "generator.jinja2"
+ )
+ self._template = Template(open(_path).read())
+ self._format_inst = self._set_format_inst()
+
+ def _parallel_queries(self, number: int) -> Callable[[int], str]:
+ """Prepares the function to update the parallel queries guide in the prompt.
+
+ Raises:
+ ValueError: if `is_parallel` is not a boolean or a list of floats.
+
+ Returns:
+ The function to generate the parallel queries guide.
+ """
+ if number > 1:
+ return (
+ "It can contain multiple parallel queries in natural language for the given functions. "
+ "They could use either the same function with different arguments or different functions.\n"
+ )
+ return ""
+
+ def _get_number(self) -> int:
+ """Generates the number of queries to generate in a single call.
+ The number must be set to `_number` to avoid changing the original value
+ when calling `_default_error`.
+ """
+ if isinstance(self.number, list):
+ self._number = random.choice(self.number)
+ elif isinstance(self.number, dict):
+ self._number = random.choices(
+ list(self.number.keys()), list(self.number.values())
+ )[0]
+ else:
+ self._number = self.number
+ return self._number
+
+ def _set_format_inst(self) -> str:
+ """Prepares the function to generate the formatted instructions for the prompt.
+
+ If the default structured output is used, returns an empty string because nothing
+ else is needed, otherwise, returns the original addition to the prompt to guide the model
+ to generate a formatted JSON.
+ """
+ return (
+ "\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n"
+ "```\n"
+ "[\n"
+ " {\n"
+ ' "query": "The generated query.",\n'
+ ' "answers": [\n'
+ " {\n"
+ ' "name": "api_name",\n'
+ ' "arguments": {\n'
+ ' "arg_name": "value"\n'
+ " ... (more arguments as required)\n"
+ " }\n"
+ " },\n"
+ " ... (more API calls as required)\n"
+ " ]\n"
+ " }\n"
+ "]\n"
+ "```\n"
+ )
+
+ def _get_func_desc(self, input: Dict[str, Any]) -> str:
+ """If available and required, will use the info from the tools in the
+ prompt for extra information. Otherwise will use jut the function description.
+ """
+ if not self.use_tools:
+ return input["func_desc"]
+ extra = "" # Extra information from the tools (if available will be added)
+ if "tools" in input:
+ extra = f"\n\nThis is the available tool to guide you (respect the order of the parameters):\n{input['tools']}"
+ return input["func_desc"] + extra
+
+ @property
+ def inputs(self) -> "StepColumns":
+ """The inputs for the task."""
+ return {
+ "examples": True,
+ "func_name": True,
+ "func_desc": True,
+ "tools": False,
+ }
+
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ """The input is formatted as a `ChatType`."""
+ number = self._get_number()
+ parallel_queries = self._parallel_queries(number)
+ return [
+ {"role": "system", "content": self.system_prompt},
+ {
+ "role": "user",
+ "content": self._template.render(
+ examples=input["examples"],
+ parallel_queries=parallel_queries,
+ number=number,
+ func_name=input["func_name"],
+ func_desc=self._get_func_desc(input),
+ format_inst=self._format_inst,
+ ),
+ },
+ ]
+
+ @property
+ def outputs(self) -> "StepColumns":
+ """The output for the task are the queries and corresponding answers."""
+ return ["query", "answers", "model_name"]
+
+ def format_output(
+ self, output: Union[str, None], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """The output is formatted as a list with the score of each instruction.
+
+ Args:
+ output: the raw output of the LLM.
+ input: the input to the task. Used for obtaining the number of responses.
+
+ Returns:
+ A dict with the queries and answers pairs.
+ The answers are an array of answers corresponding to the query.
+ Each answer is represented as an object with the following properties:
+ - name (string): The name of the tool used to generate the answer.
+ - arguments (object): An object representing the arguments passed to the tool to generate the answer.
+ Each argument is represented as a key-value pair, where the key is the parameter name and the
+ value is the corresponding value.
+ """
+ if output is None:
+ return self._default_error(input)
+
+ if not self.use_default_structured_output:
+ output = remove_fences(output)
+
+ try:
+ pairs = orjson.loads(output)
+ except orjson.JSONDecodeError:
+ return self._default_error(input)
+
+ pairs = pairs["pairs"] if self.use_default_structured_output else pairs
+
+ return self._format_output(pairs, input)
+
+ def _format_output(
+ self, pairs: Dict[str, Any], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """Parses the response, returning a dictionary with queries and answers.
+
+ Args:
+ pairs: The parsed dictionary from the LLM's output.
+ input: The input from the `LLM`.
+
+ Returns:
+ Formatted output, where the `queries` are a list of strings, and the `answers`
+ are a list of objects.
+ """
+ try:
+ input.update(
+ **{
+ "query": pairs[0]["query"],
+ "answers": json.dumps(pairs[0]["answers"]),
+ }
+ )
+ return input
+ except Exception as e:
+ self._logger.error(f"Error formatting output: {e}, pairs: '{pairs}'")
+ return self._default_error(input)
+
+ def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ """Returns a default error output, to fill the responses in case of failure."""
+ input.update(
+ **{
+ "query": None,
+ "answers": json.dumps([None] * self._number),
+ }
+ )
+ return input
+
+ @override
+ def get_structured_output(self) -> Dict[str, Any]:
+ """Creates the json schema to be passed to the LLM, to enforce generating
+ a dictionary with the output which can be directly parsed as a python dictionary.
+
+ The schema corresponds to the following:
+
+ ```python
+ from typing import Dict, List
+ from pydantic import BaseModel
+
+
+ class Answer(BaseModel):
+ name: str
+ arguments: Dict[str, str]
+
+ class QueryAnswer(BaseModel):
+ query: str
+ answers: List[Answer]
+
+ class QueryAnswerPairs(BaseModel):
+ pairs: List[QueryAnswer]
+
+ json.dumps(QueryAnswerPairs.model_json_schema(), indent=4)
+ ```
+
+ Returns:
+ JSON Schema of the response to enforce.
+ """
+ return {
+ "$defs": {
+ "Answer": {
+ "properties": {
+ "name": {"title": "Name", "type": "string"},
+ "arguments": {
+ "additionalProperties": {"type": "string"},
+ "title": "Arguments",
+ "type": "object",
+ },
+ },
+ "required": ["name", "arguments"],
+ "title": "Answer",
+ "type": "object",
+ },
+ "QueryAnswer": {
+ "properties": {
+ "query": {"title": "Query", "type": "string"},
+ "answers": {
+ "items": {"$ref": "#/$defs/Answer"},
+ "title": "Answers",
+ "type": "array",
+ },
+ },
+ "required": ["query", "answers"],
+ "title": "QueryAnswer",
+ "type": "object",
+ },
+ },
+ "properties": {
+ "pairs": {
+ "items": {"$ref": "#/$defs/QueryAnswer"},
+ "title": "Pairs",
+ "type": "array",
+ }
+ },
+ "required": ["pairs"],
+ "title": "QueryAnswerPairs",
+ "type": "object",
+ }
diff --git a/src/distilabel/steps/tasks/apigen/semantic_checker.py b/src/distilabel/steps/tasks/apigen/semantic_checker.py
new file mode 100644
index 0000000000..5ec7cdc57d
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/semantic_checker.py
@@ -0,0 +1,308 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.resources as importlib_resources
+from typing import TYPE_CHECKING, Any, Dict, Final, Union
+
+import orjson
+from jinja2 import Template
+from pydantic import PrivateAttr
+from typing_extensions import override
+
+from distilabel.steps.tasks.apigen.utils import remove_fences
+from distilabel.steps.tasks.base import Task
+
+if TYPE_CHECKING:
+ from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.typing import StepColumns
+
+
+SYSTEM_PROMPT_SEMANTIC_CHECKER: Final[str] = """\
+As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.
+These function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.
+
+Do not pass if:
+1. The function call does not align with the query’s objective, or the input arguments appear incorrect.
+2. The function call and arguments are not properly chosen from the available functions.
+3. The number of function calls does not correspond to the user’s intentions.
+4. The execution results are irrelevant and do not match the function’s purpose.
+5. The execution results contain errors or reflect that the function calls were not executed successfully.
+""".rstrip()
+
+
+class APIGenSemanticChecker(Task):
+ r"""Generate queries and answers for the given functions in JSON format.
+
+ The `APIGenGenerator` is inspired by the APIGen pipeline, which was designed to generate
+ verifiable and diverse function-calling datasets. The task generates a set of diverse queries
+ and corresponding answers for the given functions in JSON format.
+
+ Attributes:
+ system_prompt: System prompt for the task. Has a default one.
+ exclude_failed_execution: Whether to exclude failed executions (won't run on those
+ rows that have a False in `keep_row_after_execution_check` column, which
+ comes from running `APIGenExecutionChecker`). Defaults to True.
+
+ Input columns:
+ - func_desc (`str`): Description of what the function should do.
+ - query (`str`): Instruction from the user.
+ - answers (`str`): JSON encoded list with arguments to be passed to the function/API.
+ Should be loaded using `json.loads`.
+ - execution_result (`str`): Result of the function/API executed.
+
+ Output columns:
+ - thought (`str`): Reasoning for the output on whether to keep this output or not.
+ - keep_row_after_semantic_check (`bool`): True or False, can be used to filter
+ afterwards.
+
+ Categories:
+ - filtering
+ - text-generation
+
+ References:
+ - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
+ - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
+
+ Examples:
+
+ Semantic checker for generated function calls (original implementation):
+
+ ```python
+ from distilabel.steps.tasks import APIGenSemanticChecker
+ from distilabel.llms import InferenceEndpointsLLM
+
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 1024,
+ },
+ )
+ semantic_checker = APIGenSemanticChecker(
+ use_default_structured_output=False,
+ llm=llm
+ )
+ semantic_checker.load()
+
+ res = next(
+ semantic_checker.process(
+ [
+ {
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ "answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
+ "execution_result": "The Maine Coon is a big and hairy breed of cat",
+ }
+ ]
+ )
+ )
+ res
+ # [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
+ # 'query': 'What information can be obtained about the Maine Coon cat breed?',
+ # 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
+ # 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
+ # 'thought': '',
+ # 'keep_row_after_semantic_check': True,
+ # 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
+ # 'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
+ # {'role': 'user',
+ # 'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n```\n{\n "thought": "Concisely describe your reasoning here",\n "pass": "yes" or "no"\n}\n```\n'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+
+ Semantic checker for generated function calls (structured output):
+
+ ```python
+ from distilabel.steps.tasks import APIGenSemanticChecker
+ from distilabel.llms import InferenceEndpointsLLM
+
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 1024,
+ },
+ )
+ semantic_checker = APIGenSemanticChecker(
+ use_default_structured_output=True,
+ llm=llm
+ )
+ semantic_checker.load()
+
+ res = next(
+ semantic_checker.process(
+ [
+ {
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ "answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
+ "execution_result": "The Maine Coon is a big and hairy breed of cat",
+ }
+ ]
+ )
+ )
+ res
+ # [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
+ # 'query': 'What information can be obtained about the Maine Coon cat breed?',
+ # 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
+ # 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
+ # 'keep_row_after_semantic_check': True,
+ # 'thought': '',
+ # 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
+ # 'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
+ # {'role': 'user',
+ # 'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+ """
+
+ system_prompt: str = SYSTEM_PROMPT_SEMANTIC_CHECKER
+ use_default_structured_output: bool = False
+
+ _format_inst: Union[str, None] = PrivateAttr(None)
+
+ def load(self) -> None:
+ """Loads the template for the generator prompt."""
+ super().load()
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "apigen"
+ / "semantic_checker.jinja2"
+ )
+
+ self._template = Template(open(_path).read())
+ self._format_inst = self._set_format_inst()
+
+ def _set_format_inst(self) -> str:
+ """Prepares the function to generate the formatted instructions for the prompt.
+
+ If the default structured output is used, returns an empty string because nothing
+ else is needed, otherwise, returns the original addition to the prompt to guide the model
+ to generate a formatted JSON.
+ """
+ return (
+ "\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n"
+ "```\n"
+ "{\n"
+ ' "thought": "Concisely describe your reasoning here",\n'
+ ' "passes": "yes" or "no"\n'
+ "}\n"
+ "```\n"
+ )
+
+ @property
+ def inputs(self) -> "StepColumns":
+ """The inputs for the task."""
+ return {
+ "func_desc": True,
+ "query": True,
+ "answers": True,
+ "execution_result": True,
+ "keep_row_after_execution_check": True,
+ }
+
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ """The input is formatted as a `ChatType`."""
+ return [
+ {"role": "system", "content": self.system_prompt},
+ {
+ "role": "user",
+ "content": self._template.render(
+ func_desc=input["func_desc"],
+ query=input["query"] or "",
+ func_call=input["answers"] or "",
+ execution_result=input["execution_result"],
+ format_inst=self._format_inst,
+ ),
+ },
+ ]
+
+ @property
+ def outputs(self) -> "StepColumns":
+ """The output for the task are the queries and corresponding answers."""
+ return ["keep_row_after_semantic_check", "thought"]
+
+ def format_output(
+ self, output: Union[str, None], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """The output is formatted as a list with the score of each instruction.
+
+ Args:
+ output: the raw output of the LLM.
+ input: the input to the task. Used for obtaining the number of responses.
+
+ Returns:
+ A dict with the queries and answers pairs.
+ The answers are an array of answers corresponding to the query.
+ Each answer is represented as an object with the following properties:
+ - name (string): The name of the tool used to generate the answer.
+ - arguments (object): An object representing the arguments passed to the tool to generate the answer.
+ Each argument is represented as a key-value pair, where the key is the parameter name and the
+ value is the corresponding value.
+ """
+ if output is None:
+ return self._default_error(input)
+
+ output = remove_fences(output)
+
+ try:
+ result = orjson.loads(output)
+ # Update the column name and change to bool
+ result["keep_row_after_semantic_check"] = (
+ result.pop("passes").lower() == "yes"
+ )
+ input.update(**result)
+ return input
+ except orjson.JSONDecodeError:
+ return self._default_error(input)
+
+ def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ """Default error message for the task."""
+ input.update({"thought": None, "keep_row_after_semantic_check": None})
+ return input
+
+ @override
+ def get_structured_output(self) -> Dict[str, Any]:
+ """Creates the json schema to be passed to the LLM, to enforce generating
+ a dictionary with the output which can be directly parsed as a python dictionary.
+
+ The schema corresponds to the following:
+
+ ```python
+ from typing import Literal
+ from pydantic import BaseModel
+ import json
+
+ class Checker(BaseModel):
+ thought: str
+ passes: Literal["yes", "no"]
+
+ json.dumps(Checker.model_json_schema(), indent=4)
+ ```
+
+ Returns:
+ JSON Schema of the response to enforce.
+ """
+ return {
+ "properties": {
+ "thought": {"title": "Thought", "type": "string"},
+ "passes": {"enum": ["yes", "no"], "title": "Passes", "type": "string"},
+ },
+ "required": ["thought", "passes"],
+ "title": "Checker",
+ "type": "object",
+ }
diff --git a/src/distilabel/steps/tasks/apigen/utils.py b/src/distilabel/steps/tasks/apigen/utils.py
new file mode 100644
index 0000000000..85ff0b764c
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/utils.py
@@ -0,0 +1,194 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.util
+import re
+import signal
+from typing import TYPE_CHECKING, Any, Callable, Dict, TypedDict, Union
+
+from distilabel.steps.base import Step, StepInput
+
+if TYPE_CHECKING:
+ from types import ModuleType
+
+ from distilabel.steps.typing import StepColumns, StepOutput
+
+
+class PrepareExamples(Step):
+ r"""Helper step to create examples from `query` and `answers` pairs used as Few Shots in APIGen.
+
+ Attributes:
+ template (str): The template to format the examples.
+
+ Input columns:
+ - query (`str`): The query to generate examples from.
+ - answers (`str`): The answers to the query.
+
+ Output columns:
+ - examples (`str`): The formatted examples.
+
+ Categories:
+ - format
+
+ Examples:
+ Generate examples for APIGen:
+
+ ```python
+ from distilabel.steps.tasks.apigen.utils import PrepareExamples
+
+ prepare_examples = PrepareExamples()
+ result = next(prepare_examples.process(
+ [
+ {
+ "query": ['I need the area of circles with radius 2.5, 5, and 7.5 inches, please.', 'Can you provide the current locations of buses and trolleys on route 12?'],
+ "answers": ['[{"name": "circle_area", "arguments": {"radius": 2.5}}, {"name": "circle_area", "arguments": {"radius": 5}}, {"name": "circle_area", "arguments": {"radius": 7.5}}]', '[{"name": "bus_trolley_locations", "arguments": {"route": "12"}}]']
+ }
+ ]
+ )
+ # result
+ # [{'examples': '## Query:\nI need the area of circles with radius 2.5, 5, and 7.5 inches, please.\n## Answers:\n[{"name": "circle_area", "arguments": {"radius": 2.5}}, {"name": "circle_area", "arguments": {"radius": 5}}, {"name": "circle_area", "arguments": {"radius": 7.5}}]\n\n## Query:\nCan you provide the current locations of buses and trolleys on route 12?\n## Answers:\n[{"name": "bus_trolley_locations", "arguments": {"route": "12"}}]'}, {'examples': '## Query:\nI need the area of circles with radius 2.5, 5, and 7.5 inches, please.\n## Answers:\n[{"name": "circle_area", "arguments": {"radius": 2.5}}, {"name": "circle_area", "arguments": {"radius": 5}}, {"name": "circle_area", "arguments": {"radius": 7.5}}]\n\n## Query:\nCan you provide the current locations of buses and trolleys on route 12?\n## Answers:\n[{"name": "bus_trolley_locations", "arguments": {"route": "12"}}]'}]
+ ```
+ """
+
+ template: str = "## Query:\n{query}\n## Answers:\n{answers}"
+
+ @property
+ def inputs(self) -> "StepColumns":
+ return ["query", "answers"]
+
+ @property
+ def outputs(self) -> "StepColumns":
+ return ["examples"]
+
+ def process(self, inputs: StepInput) -> "StepOutput":
+ """The process prepares the data for the `APIGenGenerator` task.
+
+ If a single example is provided, it is copied to avoid raising an error.
+
+ Args:
+ inputs: A list of dictionaries with the input data.
+
+ Yields:
+ A list of dictionaries with the output data.
+ """
+ outputs = []
+ for input in inputs:
+ example_list = []
+ for query, answers in zip(input["query"], input["answers"]):
+ example_list.append(self.template.format(query=query, answers=answers))
+ outputs.append({"examples": "\n\n".join(example_list)})
+
+ yield outputs
+
+
+def load_module_from_path(path: str) -> "ModuleType":
+ """Loads a python module from a given path.
+
+ Args:
+ path: Path pointing to the module.
+
+ Returns:
+ ModuleType
+
+ Example:
+ ```python
+ path = "/path/to/module.py"
+ module = load_module_from_path(path)
+ # And you can load functions from the module like this:
+ function = getattr(module, "function_name")
+ function(*args, **kwargs)
+ ```
+ """
+ spec = importlib.util.spec_from_file_location("module.name", path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ return module
+
+
+class FunctionResult(TypedDict):
+ keep: bool
+ execution_result: str
+
+
+def execute_from_response(
+ function: Callable, call_answer: Union[Dict[str, Any], None]
+) -> FunctionResult:
+ """Executes a function with the given arguments as generated by `APIGenGenerator`.
+
+ Given that we cannot cast all the arguments arbitrarily, we try to evaluate them,
+ which ensures the strings can be converted to the correct type if possible (say
+ a list of lists of ints will be passed as such instead of its string representation).
+
+ Args:
+ function: A callable object.
+ call_answer: The arguments to call the function, as generated by the model.
+
+ Returns:
+ A container with the result of the execution and if the row should be kept.
+ """
+ if not function:
+ return FunctionResult(keep=False, execution_result="Function not found")
+
+ if call_answer:
+ for key, value in call_answer.items():
+ if isinstance(value, str):
+ try:
+ call_answer[key] = eval(value)
+ except Exception:
+ # Leave as is and expect the function to handle it
+ pass
+
+ try:
+ if call_answer:
+ result = run_function_with_timeout(function, 5, *call_answer.values())
+ else:
+ # There can be functions that do not require arguments
+ result = run_function_with_timeout(function, 5)
+ return FunctionResult(keep=True, execution_result=str(result))
+ except Exception as e:
+ return FunctionResult(keep=False, execution_result=str(e))
+
+
+def remove_json_fences(text: str) -> str:
+ pattern = r"^```json\n([\s\S]*)\n```$"
+ match = re.match(pattern, text, re.MULTILINE)
+ if match:
+ return match.group(1)
+ return text
+
+
+def remove_fences(text: str) -> str:
+ pattern = r"^```\n([\s\S]*)\n```$"
+ match = re.match(pattern, text, re.MULTILINE)
+ if match:
+ return match.group(1)
+ return text
+
+
+def timeout_handler(signum, frame):
+ raise TimeoutError("Function execution timed out")
+
+
+def run_function_with_timeout(function: Callable, timeout: int = 5, *args: Any) -> Any:
+ """Run a function with a timeout, to limit the total time waiting for a result."""
+ signal.signal(signal.SIGALRM, timeout_handler)
+ signal.alarm(timeout)
+
+ try:
+ result = function(*args)
+ finally:
+ # Cancel the alarm
+ signal.alarm(0)
+
+ return result
diff --git a/src/distilabel/steps/tasks/argilla_labeller.py b/src/distilabel/steps/tasks/argilla_labeller.py
new file mode 100644
index 0000000000..d0874ed3de
--- /dev/null
+++ b/src/distilabel/steps/tasks/argilla_labeller.py
@@ -0,0 +1,614 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import warnings
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import orjson as json
+from jinja2 import Template
+from pydantic import BaseModel, Field, PrivateAttr
+from typing_extensions import override
+
+from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.base import StepInput
+from distilabel.steps.tasks.base import Task
+
+if sys.version_info < (3, 9):
+ import importlib_resources
+else:
+ import importlib.resources as importlib_resources
+
+if TYPE_CHECKING:
+ from argilla import (
+ LabelQuestion,
+ MultiLabelQuestion,
+ RatingQuestion,
+ Record,
+ TextField,
+ TextQuestion,
+ )
+
+ from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.typing import StepOutput
+
+
+class ArgillaLabeller(Task):
+ """
+ Annotate Argilla records based on input fields, example records and question settings.
+
+ This task is designed to facilitate the annotation of Argilla records by leveraging a pre-trained LLM.
+ It uses a system prompt that guides the LLM to understand the input fields, the question type,
+ and the question settings. The task then formats the input data and generates a response based on the question.
+ The response is validated against the question's value model, and the final suggestion is prepared for annotation.
+
+ Attributes:
+ _template: a Jinja2 template used to format the input for the LLM.
+
+ Input columns:
+ - record (`argilla.Record`): The record to be annotated.
+ - fields (`Optional[List[Dict[str, Any]]]`): The list of field settings for the input fields.
+ - question (`Optional[Dict[str, Any]]`): The question settings for the question to be answered.
+ - example_records (`Optional[List[Dict[str, Any]]]`): The few shot example records with responses to be used to answer the question.
+ - guidelines (`Optional[str]`): The guidelines for the annotation task.
+
+ Output columns:
+ - suggestion (`Dict[str, Any]`): The final suggestion for annotation.
+
+ Categories:
+ - text-classification
+ - scorer
+ - text-generation
+
+ References:
+ - [`Argilla: Argilla is a collaboration tool for AI engineers and domain experts to build high-quality datasets`](https://github.com/argilla-io/argilla/)
+
+ Examples:
+ Annotate a record with the same dataset and question:
+
+ ```python
+ import argilla as rg
+ from argilla import Suggestion
+ from distilabel.steps.tasks import ArgillaLabeller
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Get information from Argilla dataset definition
+ dataset = rg.Dataset("my_dataset")
+ pending_records_filter = rg.Filter(("status", "==", "pending"))
+ completed_records_filter = rg.Filter(("status", "==", "completed"))
+ pending_records = list(
+ dataset.records(
+ query=rg.Query(filter=pending_records_filter),
+ limit=5,
+ )
+ )
+ example_records = list(
+ dataset.records(
+ query=rg.Query(filter=completed_records_filter),
+ limit=5,
+ )
+ )
+ field = dataset.settings.fields["text"]
+ question = dataset.settings.questions["label"]
+
+ # Initialize the labeller with the model and fields
+ labeller = ArgillaLabeller(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ ),
+ fields=[field],
+ question=question,
+ example_records=example_records,
+ guidelines=dataset.guidelines
+ )
+ labeller.load()
+
+ # Process the pending records
+ result = next(
+ labeller.process(
+ [
+ {
+ "record": record
+ } for record in pending_records
+ ]
+ )
+ )
+
+ # Add the suggestions to the records
+ for record, suggestion in zip(pending_records, result):
+ record.suggestions.add(Suggestion(**suggestion["suggestion"]))
+
+ # Log the updated records
+ dataset.records.log(pending_records)
+ ```
+
+ Annotate a record with alternating datasets and questions:
+
+ ```python
+ import argilla as rg
+ from distilabel.steps.tasks import ArgillaLabeller
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Get information from Argilla dataset definition
+ dataset = rg.Dataset("my_dataset")
+ field = dataset.settings.fields["text"]
+ question = dataset.settings.questions["label"]
+ question2 = dataset.settings.questions["label2"]
+
+ # Initialize the labeller with the model and fields
+ labeller = ArgillaLabeller(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ )
+ )
+ labeller.load()
+
+ # Process the record
+ record = next(dataset.records())
+ result = next(
+ labeller.process(
+ [
+ {
+ "record": record,
+ "fields": [field],
+ "question": question,
+ },
+ {
+ "record": record,
+ "fields": [field],
+ "question": question2,
+ }
+ ]
+ )
+ )
+
+ # Add the suggestions to the record
+ for suggestion in result:
+ record.suggestions.add(rg.Suggestion(**suggestion["suggestion"]))
+
+ # Log the updated record
+ dataset.records.log([record])
+ ```
+
+ Overwrite default prompts and instructions:
+
+ ```python
+ import argilla as rg
+ from distilabel.steps.tasks import ArgillaLabeller
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Overwrite default prompts and instructions
+ labeller = ArgillaLabeller(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ ),
+ system_prompt="You are an expert annotator and labelling assistant that understands complex domains and natural language processing.",
+ question_to_label_instruction={
+ "label_selection": "Select the appropriate label from the list of provided labels.",
+ "multi_label_selection": "Select none, one or multiple labels from the list of provided labels.",
+ "text": "Provide a text response to the question.",
+ "rating": "Provide a rating for the question.",
+ },
+ )
+ labeller.load()
+ ```
+ """
+
+ system_prompt: str = (
+ "You are an expert annotator and labelling assistant that understands complex domains and natural language processing. "
+ "You are given input fields and a question. "
+ "You should create a valid JSON object as an answer to the question based on the input fields. "
+ "1. Understand the input fields and optional guidelines. "
+ "2. Understand the question type and the question settings. "
+ "3. Reason through your response step-by-step. "
+ "4. Provide a valid JSON object as an answer to the question."
+ )
+ question_to_label_instruction: Dict[str, str] = {
+ "label_selection": "Select the appropriate label from the list of provided labels.",
+ "multi_label_selection": "Select none, one or multiple labels from the list of provided labels.",
+ "text": "Provide a text response to the question.",
+ "rating": "Provide a rating for the question.",
+ }
+ example_records: Optional[
+ RuntimeParameter[Union[List[Union[Dict[str, Any], BaseModel]], None]]
+ ] = Field(
+ default=None,
+ description="The few shot serialized example records or `BaseModel`s with responses to be used to answer the question.",
+ )
+ fields: Optional[
+ RuntimeParameter[Union[List[Union[BaseModel, Dict[str, Any]]], None]]
+ ] = Field(
+ default=None,
+ description="The field serialized field settings or `BaseModel` for the fields to be used to answer the question.",
+ )
+ question: Optional[
+ RuntimeParameter[
+ Union[
+ Dict[str, Any],
+ BaseModel,
+ None,
+ ]
+ ]
+ ] = Field(
+ default=None,
+ description="The question serialized question settings or `BaseModel` for the question to be answered.",
+ )
+ guidelines: Optional[RuntimeParameter[str]] = Field(
+ default=None,
+ description="The guidelines for the annotation task.",
+ )
+
+ _template: Union[Template, None] = PrivateAttr(...)
+ _client: Optional[Any] = PrivateAttr(None)
+
+ def load(self) -> None:
+ """Loads the Jinja2 template."""
+ super().load()
+
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "argillalabeller.jinja2"
+ )
+
+ self._template = Template(open(_path).read())
+
+ @property
+ def inputs(self) -> Dict[str, bool]:
+ return {
+ "record": True,
+ "fields": False,
+ "question": False,
+ "example_records": False,
+ "guidelines": False,
+ }
+
+ def _format_record(
+ self, record: Dict[str, Any], fields: List[Dict[str, Any]]
+ ) -> str:
+ """Format the record fields into a string.
+
+ Args:
+ record (Dict[str, Any]): The record to format.
+ fields (List[Dict[str, Any]]): The fields to format.
+
+ Returns:
+ str: The formatted record fields.
+ """
+ output = []
+ for field in fields:
+ if title := field.get("title"):
+ output.append(f"title: {title}")
+ if description := field.get("description"):
+ output.append(f"description: {description}")
+ output.append(record.get("fields", {}).get(field.get("name", "")))
+ return "\n".join(output)
+
+ def _get_label_instruction(self, question: Dict[str, Any]) -> str:
+ """Get the label instruction for the question.
+
+ Args:
+ question (Dict[str, Any]): The question to get the label instruction for.
+
+ Returns:
+ str: The label instruction for the question.
+ """
+ question_type = question["settings"]["type"]
+ return self.question_to_label_instruction[question_type]
+
+ def _format_question(self, question: Dict[str, Any]) -> str:
+ """Format the question settings into a string.
+
+ Args:
+ question (Dict[str, Any]): The question to format.
+
+ Returns:
+ str: The formatted question.
+ """
+ output = [
+ f"title: {question.get('title', '')}",
+ f"description: {question.get('description', '')}",
+ f"label_instruction: {self._get_label_instruction(question)}",
+ ]
+ settings = question.get("settings", {})
+ if "options" in settings:
+ output.append(
+ f"labels: {[option['value'] for option in settings.get('options', [])]}"
+ )
+ return "\n".join(output)
+
+ def _format_example_records(
+ self,
+ records: List[Dict[str, Any]],
+ fields: List[Dict[str, Any]],
+ question: Dict[str, Any],
+ ) -> str:
+ """Format the example records into a string.
+
+ Args:
+ records (List[Dict[str, Any]]): The records to format.
+ fields (List[Dict[str, Any]]): The fields to format.
+ question (Dict[str, Any]): The question to format.
+
+ Returns:
+ str: The formatted example records.
+ """
+ base = []
+ for record in records:
+ responses = record.get("responses", {})
+ if responses.get(question["name"]):
+ base.append(self._format_record(record, fields))
+ value = responses[question["name"]][0]["value"]
+ formatted_value = self._assign_value_to_question_value_model(
+ value, question
+ )
+ base.append(f"Response: {formatted_value}")
+ base.append("")
+ else:
+ warnings.warn(
+ f"Record {record} has no response for question {question['name']}. Skipping example record.",
+ stacklevel=2,
+ )
+ return "\n".join(base)
+
+ def format_input(
+ self,
+ input: Dict[
+ str,
+ Union[
+ Dict[str, Any],
+ "Record",
+ "TextField",
+ "MultiLabelQuestion",
+ "LabelQuestion",
+ "RatingQuestion",
+ "TextQuestion",
+ ],
+ ],
+ ) -> "ChatType":
+ """Format the input into a chat message.
+
+ Args:
+ input: The input to format.
+
+ Returns:
+ The formatted chat message.
+
+ Raises:
+ ValueError: If question or fields are not provided.
+ """
+ input_keys = list(self.inputs.keys())
+ record = input[input_keys[0]]
+ fields = input.get(input_keys[1], self.fields)
+ question = input.get(input_keys[2], self.question)
+ examples = input.get(input_keys[3], self.example_records)
+ guidelines = input.get(input_keys[4], self.guidelines)
+
+ if question is None:
+ raise ValueError("Question must be provided.")
+ if fields is None or any(field is None for field in fields):
+ raise ValueError("Fields must be provided.")
+
+ record = record.to_dict() if not isinstance(record, dict) else record
+ question = question.serialize() if not isinstance(question, dict) else question
+ fields = [
+ field.serialize() if not isinstance(field, dict) else field
+ for field in fields
+ ]
+ examples = (
+ [
+ example.to_dict() if not isinstance(example, dict) else example
+ for example in examples
+ ]
+ if examples
+ else None
+ )
+
+ formatted_fields = self._format_record(record, fields)
+ formatted_question = self._format_question(question)
+ formatted_examples = (
+ self._format_example_records(examples, fields, question)
+ if examples
+ else False
+ )
+
+ prompt = self._template.render(
+ fields=formatted_fields,
+ question=formatted_question,
+ examples=formatted_examples,
+ guidelines=guidelines,
+ )
+
+ messages = []
+ if self.system_prompt:
+ messages.append({"role": "system", "content": self.system_prompt})
+ messages.append({"role": "user", "content": prompt})
+ return messages
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["suggestion"]
+
+ def format_output(
+ self, output: Union[str, None], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """Format the output into a dictionary.
+
+ Args:
+ output (Union[str, None]): The output to format.
+ input (Dict[str, Any]): The input to format.
+
+ Returns:
+ Dict[str, Any]: The formatted output.
+ """
+ from argilla import Suggestion
+
+ question: Union[
+ Any,
+ Dict[str, Any],
+ LabelQuestion,
+ MultiLabelQuestion,
+ RatingQuestion,
+ TextQuestion,
+ None,
+ ] = input.get(list(self.inputs.keys())[2], self.question) or self.question
+ question = question.serialize() if not isinstance(question, dict) else question
+ model = self._get_pydantic_model_of_structured_output(question)
+ validated_output = model(**json.loads(output))
+ value = self._get_value_from_question_value_model(validated_output)
+ suggestion = Suggestion(
+ value=value,
+ question_name=question["name"],
+ type="model",
+ agent=self.llm.model_name,
+ ).serialize()
+ return {
+ self.outputs[0]: {
+ k: v
+ for k, v in suggestion.items()
+ if k in ["value", "question_name", "type", "agent"]
+ }
+ }
+
+ def _set_llm_structured_output_for_question(self, question: Dict[str, Any]) -> None:
+ runtime_parameters = self.llm._runtime_parameters
+ runtime_parameters.update(
+ {
+ "structured_output": {
+ "format": "json",
+ "schema": self._get_pydantic_model_of_structured_output(question),
+ },
+ }
+ )
+ self.llm.set_runtime_parameters(runtime_parameters)
+
+ @override
+ def process(self, inputs: StepInput) -> "StepOutput":
+ """Process the input through the task.
+
+ Args:
+ inputs (StepInput): The input to process.
+
+ Returns:
+ StepOutput: The output of the task.
+ """
+
+ question_list = [input.get("question", self.question) for input in inputs]
+ fields_list = [input.get("fields", self.fields) for input in inputs]
+ # check if any field for the field in fields is None
+ for fields in fields_list:
+ if any(field is None for field in fields):
+ raise ValueError(
+ "Fields must be provided during init or through `process` method."
+ )
+ # check if any question is None
+ if any(question is None for question in question_list):
+ raise ValueError(
+ "Question must be provided during init or through `process` method."
+ )
+ question_list = [
+ question.serialize() if not isinstance(question, dict) else question
+ for question in question_list
+ ]
+ if not all(question == question_list[0] for question in question_list):
+ warnings.warn(
+ "Not all questions are the same. Processing each question separately by setting the structured output for each question. This may impact performance.",
+ stacklevel=2,
+ )
+ for input, question in zip(inputs, question_list):
+ self._set_llm_structured_output_for_question(question)
+ yield from super().process([input])
+ else:
+ question = question_list[0]
+ self._set_llm_structured_output_for_question(question)
+ yield from super().process(inputs)
+
+ def _get_value_from_question_value_model(
+ self, question_value_model: BaseModel
+ ) -> Any:
+ """Get the value from the question value model.
+
+ Args:
+ question_value_model (BaseModel): The question value model to get the value from.
+
+ Returns:
+ Any: The value from the question value model.
+ """
+ for attr in ["label", "labels", "rating", "text"]:
+ if hasattr(question_value_model, attr):
+ return getattr(question_value_model, attr)
+ raise ValueError(f"Unsupported question type: {question_value_model}")
+
+ def _assign_value_to_question_value_model(
+ self, value: Any, question: Dict[str, Any]
+ ) -> BaseModel:
+ """Assign the value to the question value model.
+
+ Args:
+ value (Any): The value to assign.
+ question (Dict[str, Any]): The question to assign the value to.
+
+ Returns:
+ BaseModel: The question value model with the assigned value.
+ """
+ question_value_model = self._get_pydantic_model_of_structured_output(question)
+ for attr in ["label", "labels", "rating", "text"]:
+ try:
+ model_dict = {attr: value}
+ question_value_model = question_value_model(**model_dict)
+ return question_value_model.model_dump_json()
+ except AttributeError:
+ pass
+ return value
+
+ def _get_pydantic_model_of_structured_output(
+ self,
+ question: Dict[str, Any],
+ ) -> BaseModel:
+ """Get the Pydantic model of the structured output.
+
+ Args:
+ question (Dict[str, Any]): The question to get the Pydantic model of the structured output for.
+
+ Returns:
+ BaseModel: The Pydantic model of the structured output.
+ """
+
+ question_type = question["settings"]["type"]
+
+ if question_type == "multi_label_selection":
+
+ class QuestionValueModel(BaseModel):
+ labels: Optional[List[str]] = Field(default_factory=list)
+
+ elif question_type == "label_selection":
+
+ class QuestionValueModel(BaseModel):
+ label: str
+
+ elif question_type == "text":
+
+ class QuestionValueModel(BaseModel):
+ text: str
+
+ elif question_type == "rating":
+
+ class QuestionValueModel(BaseModel):
+ rating: int
+ else:
+ raise ValueError(f"Unsupported question type: {question}")
+
+ return QuestionValueModel
diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py
index 7281d6dd7e..0524749e26 100644
--- a/src/distilabel/steps/tasks/base.py
+++ b/src/distilabel/steps/tasks/base.py
@@ -12,17 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
-from pydantic import Field
+from pydantic import Field, PrivateAttr
from typing_extensions import override
from distilabel.constants import DISTILABEL_METADATA_KEY
+from distilabel.errors import DistilabelUserError
from distilabel.llms.base import LLM
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import (
GeneratorStep,
+ GlobalStep,
Step,
StepInput,
_Step,
@@ -31,7 +34,7 @@
if TYPE_CHECKING:
from distilabel.llms.typing import GenerateOutput
- from distilabel.steps.tasks.typing import FormattedInput
+ from distilabel.steps.tasks.typing import ChatType, FormattedInput
from distilabel.steps.typing import StepOutput
@@ -60,13 +63,52 @@ class _Task(_Step, ABC):
" of the `distilabel_metadata` dictionary output column"
),
)
+ add_raw_input: RuntimeParameter[bool] = Field(
+ default=True,
+ description=(
+ "Whether to include the raw input of the LLM in the key `raw_input_`"
+ " of the `distilabel_metadata` dictionary column"
+ ),
+ )
num_generations: RuntimeParameter[int] = Field(
default=1, description="The number of generations to be produced per input."
)
+ use_default_structured_output: bool = False
+
+ _can_be_used_with_offline_batch_generation: bool = PrivateAttr(False)
+
+ def model_post_init(self, __context: Any) -> None:
+ if (
+ self.llm.use_offline_batch_generation
+ and not self._can_be_used_with_offline_batch_generation
+ ):
+ raise DistilabelUserError(
+ f"`{self.__class__.__name__}` task cannot be used with offline batch generation"
+ " feature.",
+ page="sections/how_to_guides/advanced/offline-batch-generation",
+ )
+
+ super().model_post_init(__context)
+
+ @property
+ def is_global(self) -> bool:
+ """Extends the `is_global` property to return `True` if the task is using the
+ offline batch generation feature, otherwise it returns the value of the parent
+ class property. `offline_batch_generation` requires to receive all the inputs
+ at once, so for the `_BatchManager` this is a global step.
+
+ Returns:
+ Whether the task is a global step or not.
+ """
+ if self.llm.use_offline_batch_generation:
+ return True
+
+ return super().is_global
def load(self) -> None:
"""Loads the LLM via the `LLM.load()` method."""
super().load()
+ self._set_default_structured_output()
self.llm.load()
@override
@@ -75,6 +117,28 @@ def unload(self) -> None:
self._logger.debug("Executing task unload logic.")
self.llm.unload()
+ @override
+ def impute_step_outputs(
+ self, step_output: List[Dict[str, Any]]
+ ) -> List[Dict[str, Any]]:
+ """
+ Imputes the outputs of the task in case the LLM failed to generate a response.
+ """
+ result = []
+ for row in step_output:
+ data = row.copy()
+ for output in self.get_outputs().keys():
+ data[output] = None
+ data = self._maybe_add_raw_input_output(
+ data,
+ None,
+ None,
+ add_raw_output=self.add_raw_output,
+ add_raw_input=self.add_raw_input,
+ )
+ result.append(data)
+ return result
+
@abstractmethod
def format_output(
self,
@@ -110,10 +174,12 @@ def _format_outputs(
for output, input in zip(outputs, inputs * len(outputs)): # type: ignore
try:
formatted_output = self.format_output(output, input)
- formatted_output = self._maybe_add_raw_output(
+ formatted_output = self._maybe_add_raw_input_output(
formatted_output,
output,
+ input,
add_raw_output=self.add_raw_output, # type: ignore
+ add_raw_input=self.add_raw_input, # type: ignore
)
formatted_outputs.append(formatted_output)
except Exception as e:
@@ -132,26 +198,171 @@ def _output_on_failure(
# Create a dictionary with the outputs of the task (every output set to None)
outputs = {output: None for output in self.outputs}
outputs["model_name"] = self.llm.model_name # type: ignore
- outputs = self._maybe_add_raw_output(
+ outputs = self._maybe_add_raw_input_output(
outputs,
output,
+ input,
add_raw_output=self.add_raw_output, # type: ignore
+ add_raw_input=self.add_raw_input, # type: ignore
)
return outputs
- def _maybe_add_raw_output(
+ def _maybe_add_raw_input_output(
self,
output: Dict[str, Any],
raw_output: Union[str, None],
+ input: Union[str, None],
add_raw_output: bool = True,
- ) -> Dict[str, Any]:
- """Adds the raw output of the LLM to the output dictionary if `add_raw_output` is True."""
+ add_raw_input: bool = True,
+ ):
+ """Adds the raw output and or the formatted input of the LLM to the output dictionary
+ if `add_raw_output` is True or `add_raw_input` is True.
+ """
+ meta = output.get(DISTILABEL_METADATA_KEY, {})
+
if add_raw_output:
- meta = output.get(DISTILABEL_METADATA_KEY, {})
meta[f"raw_output_{self.name}"] = raw_output
+ if add_raw_input:
+ meta[f"raw_input_{self.name}"] = self.format_input(input) if input else None
+ if meta:
output[DISTILABEL_METADATA_KEY] = meta
+
return output
+ def _set_default_structured_output(self) -> None:
+ """Prepares the structured output to be set in the selected `LLM`.
+
+ If the method `get_structured_output` returns None (the default), there's no need
+ to set anything, as it doesn't apply.
+ If the `use_default_structured_output` and there's no previous structured output
+ set by hand, then decide the type of structured output to select depending on the
+ `LLM` provider.
+ """
+ schema = self.get_structured_output()
+ if not schema:
+ return
+
+ if self.use_default_structured_output and not self.llm.structured_output:
+ # In case the default structured output is required, we have to set it before
+ # the LLM is loaded
+ from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.llms.base import AsyncLLM
+
+ def check_dependency(module_name: str) -> None:
+ if not importlib.util.find_spec(module_name):
+ raise ImportError(
+ f"`{module_name}` is not installed and is needed for the structured generation with this LLM."
+ f" Please install it using `pip install {module_name}`."
+ )
+
+ dependency = "outlines"
+ structured_output = {"schema": schema}
+ if isinstance(self.llm, InferenceEndpointsLLM):
+ structured_output.update({"format": "json"})
+ # To determine instructor or outlines format
+ elif isinstance(self.llm, AsyncLLM) and not isinstance(
+ self.llm, InferenceEndpointsLLM
+ ):
+ dependency = "instructor"
+ structured_output.update({"format": "json"})
+
+ check_dependency(dependency)
+ self.llm.structured_output = structured_output
+
+ def get_structured_output(self) -> Union[Dict[str, Any], None]:
+ """Returns the structured output for a task that implements one by default,
+ must be overriden by subclasses of `Task`. When implemented, should be a json
+ schema that enforces the response from the LLM so that it's easier to parse.
+ """
+ return None
+
+ def _sample_input(self) -> "ChatType":
+ """Returns a sample input to be used in the `print` method.
+ Tasks that don't adhere to a format input that returns a map of the type
+ str -> str should override this method to return a sample input.
+ """
+ return self.format_input(
+ {input: f"" for input in self.inputs}
+ )
+
+ def print(self, sample_input: Optional["ChatType"] = None) -> None:
+ """Prints a sample input to the console using the `rich` library.
+ Helper method to visualize the prompt of the task.
+
+ Args:
+ sample_input: A sample input to be printed. If not provided, a default will be
+ generated using the `_sample_input` method, which can be overriden by
+ subclasses. This should correspond to the same example you could pass to
+ the `format_input` method.
+ The variables be named by default.
+
+ Examples:
+ Print the URIAL prompt:
+
+ ```python
+ from distilabel.steps.tasks import URIAL
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ urial = URIAL(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+ )
+ urial.load()
+ urial.print()
+ ╭─────────────────────────────────────── Prompt: URIAL ────────────────────────────────────────╮
+ │ ╭────────────────────────────────────── User Message ───────────────────────────────────────╮ │
+ │ │ # Instruction │ │
+ │ │ │ │
+ │ │ Below is a list of conversations between a human and an AI assistant (you). │ │
+ │ │ Users place their queries under "# User:", and your responses are under "# Assistant:". │ │
+ │ │ You are a helpful, respectful, and honest assistant. │ │
+ │ │ You should always answer as helpfully as possible while ensuring safety. │ │
+ │ │ Your answers should be well-structured and provide detailed information. They should also │ │
+ │ │ have an engaging tone. │ │
+ │ │ Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, │ │
+ │ │ dangerous, or illegal content, even if it may be helpful. │ │
+ │ │ Your response must be socially responsible, and thus you can refuse to answer some │ │
+ │ │ controversial topics. │ │
+ │ │ │ │
+ │ │ │ │
+ │ │ # User: │ │
+ │ │ │ │
+ │ │ │ │
+ │ │ │ │
+ │ │ # Assistant: │ │
+ │ ╰───────────────────────────────────────────────────────────────────────────────────────────╯ │
+ ╰───────────────────────────────────────────────────────────────────────────────────────────────╯
+ ```
+ """
+ from rich.console import Console, Group
+ from rich.panel import Panel
+ from rich.text import Text
+
+ console = Console()
+ sample_input = sample_input or self._sample_input()
+
+ panels = []
+ for item in sample_input:
+ content = Text.assemble((item.get("content", ""),))
+ panel = Panel(
+ content,
+ title=f"[bold][magenta]{item.get('role', '').capitalize()} Message[/magenta][/bold]",
+ border_style="light_cyan3",
+ )
+ panels.append(panel)
+
+ # Create a group of panels
+ # Wrap the group in an outer panel
+ outer_panel = Panel(
+ Group(*panels),
+ title=f"[bold][magenta]Prompt: {type(self).__name__} [/magenta][/bold]",
+ border_style="light_cyan3",
+ expand=False,
+ )
+ console.print(outer_panel)
+
class Task(_Task, Step):
"""Task is a class that implements the `_Task` abstract class and adds the `Step`
@@ -195,7 +406,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
formatted_inputs = self._format_inputs(inputs)
# `outputs` is a list containing a list of generations per input
- outputs = self.llm.generate(
+ outputs = self.llm.generate_outputs(
inputs=formatted_inputs,
num_generations=self.num_generations, # type: ignore
**self.llm.get_generation_kwargs(), # type: ignore
@@ -222,7 +433,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
class GeneratorTask(_Task, GeneratorStep):
- """GeneratorTask is a class that implements the `_Task` abstract class and adds the
+ """`GeneratorTask` is a class that implements the `_Task` abstract class and adds the
`GeneratorStep` interface to be used as a step in the pipeline.
Attributes:
@@ -233,3 +444,12 @@ class GeneratorTask(_Task, GeneratorStep):
"""
pass
+
+
+class GlobalTask(_Task, GlobalStep):
+ """`GlobalTask` is a class that implements the `_Task` abstract class and adds the
+ `GlobalStep` interface to be used as a step in the pipeline. It's generally used in
+ combination with `LLM`s that can be used for offline batched inference.
+ """
+
+ pass
diff --git a/src/distilabel/steps/tasks/clair.py b/src/distilabel/steps/tasks/clair.py
new file mode 100644
index 0000000000..cbf189ab72
--- /dev/null
+++ b/src/distilabel/steps/tasks/clair.py
@@ -0,0 +1,199 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.resources as importlib_resources
+from typing import TYPE_CHECKING, Any, Dict, Final, Union
+
+from jinja2 import Template
+from pydantic import PrivateAttr
+
+from distilabel.steps.tasks.base import Task
+
+if TYPE_CHECKING:
+ from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.typing import StepColumns
+
+
+SYSTEM_PROMPT: Final[str] = (
+ "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution."
+)
+
+
+class CLAIR(Task):
+ r"""Contrastive Learning from AI Revisions (CLAIR).
+
+ CLAIR uses an AI system to minimally revise a solution A→A´ such that the resulting
+ preference A `preferred` A’ is much more contrastive and precise.
+
+ Input columns:
+ - task (`str`): The task or instruction.
+ - student_solution (`str`): An answer to the task that is to be revised.
+
+ Output columns:
+ - revision (`str`): The revised text.
+ - rational (`str`): The rational for the provided revision.
+ - model_name (`str`): The name of the model used to generate the revision and rational.
+
+ Categories:
+ - preference
+ - text-generation
+
+ References:
+ - [`Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment`](https://arxiv.org/abs/2408.06266v1)
+ - [`APO and CLAIR - GitHub Repository`](https://github.com/ContextualAI/CLAIR_and_APO)
+
+ Examples:
+ Create contrastive preference pairs:
+
+ ```python
+ from distilabel.steps.tasks import CLAIR
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 4096,
+ },
+ )
+ clair_task = CLAIR(llm=llm)
+
+ clair_task.load()
+
+ result = next(
+ clair_task.process(
+ [
+ {
+ "task": "How many gaps are there between the earth and the moon?",
+ "student_solution": 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon's orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.'
+ }
+ ]
+ )
+ )
+ # result
+ # [{'task': 'How many gaps are there between the earth and the moon?',
+ # 'student_solution': 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.',
+ # 'revision': 'There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.',
+ # 'rational': 'The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.',
+ # 'distilabel_metadata': {'raw_output_c_l_a_i_r_0': '{teacher_reasoning}: The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.\n\n{corrected_student_solution}: There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.',
+ # 'raw_input_c_l_a_i_r_0': [{'role': 'system',
+ # 'content': "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution."},
+ # {'role': 'user',
+ # 'content': '{task}: How many gaps are there between the earth and the moon?\n\n{student_solution}: There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.\n\n-----------------\n\nLet\'s first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer.'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+
+ Citations:
+
+ ```
+ @misc{doosterlinck2024anchoredpreferenceoptimizationcontrastive,
+ title={Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment},
+ author={Karel D'Oosterlinck and Winnie Xu and Chris Develder and Thomas Demeester and Amanpreet Singh and Christopher Potts and Douwe Kiela and Shikib Mehri},
+ year={2024},
+ eprint={2408.06266},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG},
+ url={https://arxiv.org/abs/2408.06266},
+ }
+ ```
+ """
+
+ system_prompt: str = SYSTEM_PROMPT
+ _template: Union[Template, None] = PrivateAttr(...)
+
+ def load(self) -> None:
+ super().load()
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "clair.jinja2"
+ )
+ with open(_path, "r") as f:
+ self._template = Template(f.read())
+
+ @property
+ def inputs(self) -> "StepColumns":
+ return ["task", "student_solution"]
+
+ @property
+ def outputs(self) -> "StepColumns":
+ return ["revision", "rational", "model_name"]
+
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ """The input is formatted as a `ChatType` assuming that the instruction
+ is the first interaction from the user within a conversation."""
+ return [
+ {"role": "system", "content": self.system_prompt},
+ {
+ "role": "user",
+ "content": self._template.render(
+ task=input["task"], student_solution=input["student_solution"]
+ ),
+ },
+ ]
+
+ def format_output(
+ self, output: Union[str, None], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """The output is formatted as a list with the score of each instruction-response pair.
+
+ Args:
+ output: the raw output of the LLM.
+ input: the input to the task. Used for obtaining the number of responses.
+
+ Returns:
+ A dict with the key `scores` containing the scores for each instruction-response pair.
+ """
+ if output is None:
+ return self._default_error()
+
+ return self._format_output(output)
+
+ def _format_output(self, output: Union[str, None]) -> Dict[str, Any]:
+ if "**Corrected Student Solution:**" in output:
+ splits = output.split("**Corrected Student Solution:**")
+ elif "{corrected_student_solution}:" in output:
+ splits = output.split("{corrected_student_solution}:")
+ elif "{corrected_student_solution}" in output:
+ splits = output.split("{corrected_student_solution}")
+ elif "**Worsened Student Solution:**" in output:
+ splits = output.split("**Worsened Student Solution:**")
+ elif "{worsened_student_solution}:" in output:
+ splits = output.split("{worsened_student_solution}:")
+ elif "{worsened_student_solution}" in output:
+ splits = output.split("{worsened_student_solution}")
+ else:
+ splits = None
+
+ # Safety check when the output doesn't follow the expected format
+ if not splits:
+ return self._default_error()
+
+ if len(splits) >= 2:
+ revision = splits[1]
+ revision = revision.strip("\n\n").strip() # noqa: B005
+
+ rational = splits[0]
+ if "{teacher_reasoning}" in rational:
+ rational = rational.split("{teacher_reasoning}")[1].strip(":").strip()
+ rational = rational.strip("\n\n").strip() # noqa: B005
+ else:
+ return self._default_error()
+ return {"revision": revision, "rational": rational}
+
+ def _default_error(self) -> Dict[str, None]:
+ return {"revision": None, "rational": None}
diff --git a/src/distilabel/steps/tasks/complexity_scorer.py b/src/distilabel/steps/tasks/complexity_scorer.py
index 170d75e13e..401e3b760f 100644
--- a/src/distilabel/steps/tasks/complexity_scorer.py
+++ b/src/distilabel/steps/tasks/complexity_scorer.py
@@ -22,8 +22,10 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union
+import orjson
from jinja2 import Template
from pydantic import PrivateAttr
+from typing_extensions import override
from distilabel.steps.tasks.base import Task
@@ -61,7 +63,6 @@ class ComplexityScorer(Task):
- [`What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685)
Examples:
-
Evaluate the complexity of your instructions:
```python
@@ -86,8 +87,32 @@ class ComplexityScorer(Task):
# [{'instructions': ['plain instruction', 'highly complex instruction'], 'model_name': 'test', 'scores': [1, 5], 'distilabel_metadata': {'raw_output_complexity_scorer_0': 'output'}}]
```
- Citations:
+ Generate structured output with default schema:
+
+ ```python
+ from distilabel.steps.tasks import ComplexityScorer
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ scorer = ComplexityScorer(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ ),
+ use_default_structured_output=use_default_structured_output
+ )
+ scorer.load()
+
+ result = next(
+ scorer.process(
+ [{"instructions": ["plain instruction", "highly complex instruction"]}]
+ )
+ )
+ # result
+ # [{'instructions': ['plain instruction', 'highly complex instruction'], 'model_name': 'test', 'scores': [1, 2], 'distilabel_metadata': {'raw_output_complexity_scorer_0': '{ \\n "scores": [\\n 1, \\n 2\\n ]\\n}'}}]
+ ```
+
+ Citations:
```
@misc{liu2024makesgooddataalignment,
title={What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning},
@@ -102,6 +127,7 @@ class ComplexityScorer(Task):
"""
_template: Union[Template, None] = PrivateAttr(...)
+ _can_be_used_with_offline_batch_generation = True
def load(self) -> None:
"""Loads the Jinja2 template."""
@@ -153,6 +179,9 @@ def format_output(
if output is None:
return {"scores": [None] * len(input["instructions"])}
+ if self.use_default_structured_output:
+ return self._format_structured_output(output, input)
+
scores = []
score_lines = output.split("\n")
for i, line in enumerate(score_lines):
@@ -162,3 +191,65 @@ def format_output(
if i == len(input["instructions"]) - 1:
break
return {"scores": scores}
+
+ @override
+ def get_structured_output(self) -> Dict[str, Any]:
+ """Creates the json schema to be passed to the LLM, to enforce generating
+ a dictionary with the output which can be directly parsed as a python dictionary.
+
+ The schema corresponds to the following:
+
+ ```python
+ from pydantic import BaseModel
+ from typing import List
+
+ class SchemaComplexityScorer(BaseModel):
+ scores: List[int]
+ ```
+
+ Returns:
+ JSON Schema of the response to enforce.
+ """
+ return {
+ "properties": {
+ "scores": {
+ "items": {"type": "integer"},
+ "title": "Scores",
+ "type": "array",
+ }
+ },
+ "required": ["scores"],
+ "title": "SchemaComplexityScorer",
+ "type": "object",
+ }
+
+ def _format_structured_output(
+ self, output: str, input: Dict[str, Any]
+ ) -> Dict[str, str]:
+ """Parses the structured response, which should correspond to a dictionary
+ with either `positive`, or `positive` and `negative` keys.
+
+ Args:
+ output: The output from the `LLM`.
+
+ Returns:
+ Formatted output.
+ """
+ try:
+ return orjson.loads(output)
+ except orjson.JSONDecodeError:
+ return {"scores": [None] * len(input["instructions"])}
+
+ @override
+ def _sample_input(self) -> "ChatType":
+ """Returns a sample input to be used in the `print` method.
+ Tasks that don't adhere to a format input that returns a map of the type
+ str -> str should override this method to return a sample input.
+ """
+ return self.format_input(
+ {
+ "instructions": [
+ f"" for i in range(2)
+ ],
+ }
+ )
diff --git a/src/distilabel/steps/tasks/evol_instruct/base.py b/src/distilabel/steps/tasks/evol_instruct/base.py
index 71da271554..95f271a117 100644
--- a/src/distilabel/steps/tasks/evol_instruct/base.py
+++ b/src/distilabel/steps/tasks/evol_instruct/base.py
@@ -71,7 +71,6 @@ class EvolInstruct(Task):
- [GitHub: h2oai/h2o-wizardlm](https://github.com/h2oai/h2o-wizardlm)
Examples:
-
Evolve an instruction using an LLM:
```python
@@ -151,7 +150,6 @@ class EvolInstruct(Task):
```
Citations:
-
```
@misc{xu2023wizardlmempoweringlargelanguage,
title={WizardLM: Empowering Large Language Models to Follow Complex Instructions},
@@ -390,3 +388,9 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
):
input.update(self.format_output(instruction, answers[idx]))
yield inputs
+
+ @override
+ def _sample_input(self) -> ChatType:
+ return self.format_input(
+ self._apply_random_mutation("")
+ )
diff --git a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py
index 1619db4229..a7e46b154b 100644
--- a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py
+++ b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py
@@ -63,7 +63,6 @@ class EvolComplexity(EvolInstruct):
- [WizardLM: Empowering Large Language Models to Follow Complex Instructions](https://arxiv.org/abs/2304.12244)
Examples:
-
Evolve an instruction using an LLM:
```python
@@ -86,7 +85,6 @@ class EvolComplexity(EvolInstruct):
```
Citations:
-
```
@misc{liu2024makesgooddataalignment,
title={What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning},
diff --git a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py
index 8749fcd8c1..f1965d9e83 100644
--- a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py
+++ b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py
@@ -61,7 +61,6 @@ class EvolComplexityGenerator(EvolInstructGenerator):
- [WizardLM: Empowering Large Language Models to Follow Complex Instructions](https://arxiv.org/abs/2304.12244)
Examples:
-
Generate evolved instructions without initial instructions:
```python
@@ -84,7 +83,6 @@ class EvolComplexityGenerator(EvolInstructGenerator):
```
Citations:
-
```
@misc{liu2024makesgooddataalignment,
title={What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning},
diff --git a/src/distilabel/steps/tasks/evol_instruct/generator.py b/src/distilabel/steps/tasks/evol_instruct/generator.py
index 1eea138a69..1f56c866a3 100644
--- a/src/distilabel/steps/tasks/evol_instruct/generator.py
+++ b/src/distilabel/steps/tasks/evol_instruct/generator.py
@@ -77,7 +77,6 @@ class EvolInstructGenerator(GeneratorTask):
- [GitHub: h2oai/h2o-wizardlm](https://github.com/h2oai/h2o-wizardlm)
Examples:
-
Generate evolved instructions without initial instructions:
```python
@@ -100,7 +99,6 @@ class EvolInstructGenerator(GeneratorTask):
```
Citations:
-
```
@misc{xu2023wizardlmempoweringlargelanguage,
title={WizardLM: Empowering Large Language Models to Follow Complex Instructions},
@@ -349,3 +347,7 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore
],
True,
)
+
+ @override
+ def _sample_input(self) -> "ChatType":
+ return self._apply_random_mutation(iter_no=0)[0]
diff --git a/src/distilabel/steps/tasks/evol_quality/base.py b/src/distilabel/steps/tasks/evol_quality/base.py
index 1c0d6c4d52..5c899aa680 100644
--- a/src/distilabel/steps/tasks/evol_quality/base.py
+++ b/src/distilabel/steps/tasks/evol_quality/base.py
@@ -67,7 +67,6 @@ class EvolQuality(Task):
- [`What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685)
Examples:
-
Evolve the quality of the responses given a prompt:
```python
@@ -103,7 +102,6 @@ class EvolQuality(Task):
```
Citations:
-
```
@misc{liu2024makesgooddataalignment,
title={What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning},
@@ -273,3 +271,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
yield inputs
self._logger.info(f"🎉 Finished evolving {len(responses)} instructions!")
+
+ @override
+ def _sample_input(self) -> ChatType:
+ return self.format_input("")
diff --git a/src/distilabel/steps/tasks/generate_embeddings.py b/src/distilabel/steps/tasks/generate_embeddings.py
index 70edbd564a..85db623d94 100644
--- a/src/distilabel/steps/tasks/generate_embeddings.py
+++ b/src/distilabel/steps/tasks/generate_embeddings.py
@@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any, Dict
+from distilabel.errors import DistilabelUserError
from distilabel.llms.base import LLM
from distilabel.steps.base import Step, StepInput
from distilabel.utils.chat import is_openai_format
if TYPE_CHECKING:
from distilabel.steps.tasks.typing import ChatType
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
class GenerateEmbeddings(Step):
@@ -49,7 +50,6 @@ class GenerateEmbeddings(Step):
- [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685)
Examples:
-
Rank LLM candidates:
```python
@@ -76,7 +76,6 @@ class GenerateEmbeddings(Step):
```
Citations:
-
```
@misc{liu2024makesgooddataalignment,
title={What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning},
@@ -99,13 +98,13 @@ def load(self) -> None:
self.llm.load()
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
"""The inputs for the task is a `text` column containing either a string or a
list of dictionaries in OpenAI chat-like format."""
return ["text"]
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""The outputs for the task is an `embedding` column containing the embedding of
the `text` input."""
return ["embedding", "model_name"]
@@ -130,9 +129,10 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
if is_openai_format(text):
return text
- raise ValueError(
+ raise DistilabelUserError(
f"Couldn't format input for step {self.name}. The `text` input column has to"
- " be a string or a list of dictionaries in OpenAI chat-like format."
+ " be a string or a list of dictionaries in OpenAI chat-like format.",
+ page="components-gallery/tasks/generateembeddings/",
)
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
diff --git a/src/distilabel/steps/tasks/genstruct.py b/src/distilabel/steps/tasks/genstruct.py
index 27fd813cfd..02a0657339 100644
--- a/src/distilabel/steps/tasks/genstruct.py
+++ b/src/distilabel/steps/tasks/genstruct.py
@@ -69,7 +69,6 @@ class Genstruct(Task):
- [Ada-Instruct: Adapting Instruction Generators for Complex Reasoning](https://arxiv.org/abs/2310.04484)
Examples:
-
Generate instructions from raw documents using the title and content:
```python
@@ -105,7 +104,6 @@ class Genstruct(Task):
```
Citations:
-
```
@misc{cui2023adainstructadaptinginstructiongenerators,
title={Ada-Instruct: Adapting Instruction Generators for Complex Reasoning},
diff --git a/src/distilabel/steps/tasks/improving_text_embeddings.py b/src/distilabel/steps/tasks/improving_text_embeddings.py
index 64465c9d61..d806e3aded 100644
--- a/src/distilabel/steps/tasks/improving_text_embeddings.py
+++ b/src/distilabel/steps/tasks/improving_text_embeddings.py
@@ -12,17 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib.resources as importlib_resources
import random
import re
-import sys
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Union
-if sys.version_info < (3, 9):
- import importlib_resources
-else:
- import importlib.resources as importlib_resources
-
from jinja2 import Template
from pydantic import Field, PrivateAttr
from typing_extensions import override
@@ -232,6 +227,10 @@ def process(self, offset: int = 0) -> GeneratorStepOutput: # type: ignore
)
yield task_outputs, True
+ @override
+ def _sample_input(self) -> ChatType:
+ return self.prompt
+
# IMPLEMENTED TASKS
class EmbeddingTaskGenerator(GeneratorTask):
@@ -256,11 +255,15 @@ class EmbeddingTaskGenerator(GeneratorTask):
with one row only containing a list with around 20 tasks; otherwise, if set to `True`, it
will return a `distilabel.Distiset` with around 20 rows, each containing one task.
+ Output columns:
+ - tasks (`List[str]`): the list of tasks generated by the `LLM`.
+ - task (`str`): the task generated by the `LLM` if `flatten_tasks=True`.
+ - model_name (`str`): the name of the model used to generate the tasks.
+
References:
- [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368)
Examples:
-
Generate embedding tasks for text retrieval:
```python
@@ -280,7 +283,6 @@ class EmbeddingTaskGenerator(GeneratorTask):
```
Citations:
-
```
@misc{wang2024improvingtextembeddingslarge,
title={Improving Text Embeddings with Large Language Models},
@@ -303,6 +305,7 @@ class EmbeddingTaskGenerator(GeneratorTask):
flatten_tasks: bool = False
_template: Union[Template, None] = PrivateAttr(...)
+ _can_be_used_with_offline_batch_generation = True
def load(self) -> None:
"""Loads the Jinja2 template."""
@@ -398,6 +401,10 @@ def format_output(
pass
return {"tasks": output}
+ @override
+ def _sample_input(self) -> ChatType:
+ return self.prompt
+
class GenerateTextRetrievalData(_EmbeddingDataGeneration):
"""Generate text retrieval data with an `LLM` to later on train an embedding model.
@@ -427,11 +434,19 @@ class GenerateTextRetrievalData(_EmbeddingDataGeneration):
Defaults to `None`, meaning that it will be randomly sampled.
seed: The random seed to be set in case there's any sampling within the `format_input` method.
+ Input columns:
+ - task (`str`): The task description to be used in the generation.
+
+ Output columns:
+ - user_query (`str`): the user query generated by the `LLM`.
+ - positive_document (`str`): the positive document generated by the `LLM`.
+ - hard_negative_document (`str`): the hard negative document generated by the `LLM`.
+ - model_name (`str`): the name of the model used to generate the text retrieval data.
+
References:
- [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368)
Examples:
-
Generate synthetic text retrieval data for training embedding models:
```python
@@ -475,6 +490,7 @@ class GenerateTextRetrievalData(_EmbeddingDataGeneration):
num_words: Optional[Literal[50, 100, 200, 300, 400, 500]] = None
_template_name: str = PrivateAttr(default="text-retrieval")
+ _can_be_used_with_offline_batch_generation = True
def format_input(self, input: Dict[str, Any]) -> ChatType:
"""Method to format the input based on the `task` and the provided attributes, or just
@@ -541,11 +557,19 @@ class GenerateShortTextMatchingData(_EmbeddingDataGeneration):
seed: The random seed to be set in case there's any sampling within the `format_input` method.
Note that in this task the `seed` has no effect since there are no sampling params.
+ Input columns:
+ - task (`str`): The task description to be used in the generation.
+
+ Output columns:
+ - input (`str`): the input generated by the `LLM`.
+ - positive_document (`str`): the positive document generated by the `LLM`.
+ - model_name (`str`): the name of the model used to generate the short text matching
+ data.
+
References:
- [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368)
Examples:
-
Generate synthetic short text matching data for training embedding models:
```python
@@ -574,6 +598,7 @@ class GenerateShortTextMatchingData(_EmbeddingDataGeneration):
)
_template_name: str = PrivateAttr(default="short-text-matching")
+ _can_be_used_with_offline_batch_generation = True
def format_input(self, input: Dict[str, Any]) -> ChatType:
"""Method to format the input based on the `task` and the provided attributes, or just
@@ -622,11 +647,19 @@ class GenerateLongTextMatchingData(_EmbeddingDataGeneration):
seed: The random seed to be set in case there's any sampling within the `format_input` method.
Note that in this task the `seed` has no effect since there are no sampling params.
+ Input columns:
+ - task (`str`): The task description to be used in the generation.
+
+ Output columns:
+ - input (`str`): the input generated by the `LLM`.
+ - positive_document (`str`): the positive document generated by the `LLM`.
+ - model_name (`str`): the name of the model used to generate the long text matching
+ data.
+
References:
- [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368)
Examples:
-
Generate synthetic long text matching data for training embedding models:
```python
@@ -655,6 +688,7 @@ class GenerateLongTextMatchingData(_EmbeddingDataGeneration):
)
_template_name: str = PrivateAttr(default="long-text-matching")
+ _can_be_used_with_offline_batch_generation = True
def format_input(self, input: Dict[str, Any]) -> ChatType:
"""Method to format the input based on the `task` and the provided attributes, or just
@@ -706,11 +740,20 @@ class GenerateTextClassificationData(_EmbeddingDataGeneration):
or `ambiguous`. Defaults to `None`, meaning that it will be randomly sampled.
seed: The random seed to be set in case there's any sampling within the `format_input` method.
+ Input columns:
+ - task (`str`): The task description to be used in the generation.
+
+ Output columns:
+ - input_text (`str`): the input text generated by the `LLM`.
+ - label (`str`): the label generated by the `LLM`.
+ - misleading_label (`str`): the misleading label generated by the `LLM`.
+ - model_name (`str`): the name of the model used to generate the text classification
+ data.
+
References:
- [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368)
Examples:
-
Generate synthetic text classification data for training embedding models:
```python
@@ -746,6 +789,7 @@ class GenerateTextClassificationData(_EmbeddingDataGeneration):
] = None
_template_name: str = PrivateAttr(default="text-classification")
+ _can_be_used_with_offline_batch_generation = True
def format_input(self, input: Dict[str, Any]) -> ChatType:
"""Method to format the input based on the `task` and the provided attributes, or just
@@ -802,8 +846,13 @@ class MonolingualTripletGenerator(_EmbeddingDataGenerator):
Defaults to `None`, meaning that it will be randomly sampled.
seed: The random seed to be set in case there's any sampling within the `format_input` method.
- Examples:
+ Output columns:
+ - S1 (`str`): the first sentence generated by the `LLM`.
+ - S2 (`str`): the second sentence generated by the `LLM`.
+ - S3 (`str`): the third sentence generated by the `LLM`.
+ - model_name (`str`): the name of the model used to generate the monolingual triplets.
+ Examples:
Generate monolingual triplets for training embedding models:
```python
@@ -837,6 +886,7 @@ class MonolingualTripletGenerator(_EmbeddingDataGenerator):
low_score: Optional[Literal["2.5", "3", "3.5"]] = None
_template_name: str = PrivateAttr(default="monolingual-triplet")
+ _can_be_used_with_offline_batch_generation = True
@property
def prompt(self) -> ChatType:
@@ -887,8 +937,14 @@ class BitextRetrievalGenerator(_EmbeddingDataGenerator):
Defaults to `None`, meaning that it will be randomly sampled.
seed: The random seed to be set in case there's any sampling within the `format_input` method.
- Examples:
+ Output columns:
+ - S1 (`str`): the first sentence generated by the `LLM`.
+ - S2 (`str`): the second sentence generated by the `LLM`.
+ - S3 (`str`): the third sentence generated by the `LLM`.
+ - model_name (`str`): the name of the model used to generate the bitext retrieval
+ data.
+ Examples:
Generate bitext retrieval data for training embedding models:
```python
@@ -927,6 +983,7 @@ class BitextRetrievalGenerator(_EmbeddingDataGenerator):
low_score: Optional[Literal["2.5", "3", "3.5"]] = None
_template_name: str = PrivateAttr(default="bitext-retrieval")
+ _can_be_used_with_offline_batch_generation = True
@property
def prompt(self) -> ChatType:
diff --git a/src/distilabel/steps/tasks/instruction_backtranslation.py b/src/distilabel/steps/tasks/instruction_backtranslation.py
index 3833333192..a0420ef8f3 100644
--- a/src/distilabel/steps/tasks/instruction_backtranslation.py
+++ b/src/distilabel/steps/tasks/instruction_backtranslation.py
@@ -50,8 +50,43 @@ class InstructionBacktranslation(Task):
References:
- [`Self-Alignment with Instruction Backtranslation`](https://arxiv.org/abs/2308.06259)
- Citations:
+ Examples:
+ Generate a score and reason for a given instruction and generation:
+
+ ```python
+ from distilabel.steps.tasks import InstructionBacktranslation
+
+ instruction_backtranslation = InstructionBacktranslation(
+ name="instruction_backtranslation",
+ llm=llm,
+ input_batch_size=10,
+ output_mappings={"model_name": "scoring_model"},
+ )
+ instruction_backtranslation.load()
+
+ result = next(
+ instruction_backtranslation.process(
+ [
+ {
+ "instruction": "How much is 2+2?",
+ "generation": "4",
+ }
+ ]
+ )
+ )
+ # result
+ # [
+ # {
+ # "instruction": "How much is 2+2?",
+ # "generation": "4",
+ # "score": 3,
+ # "reason": "Reason for the generation.",
+ # "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+ # }
+ # ]
+ ```
+ Citations:
```
@misc{li2024selfalignmentinstructionbacktranslation,
title={Self-Alignment with Instruction Backtranslation},
@@ -66,6 +101,7 @@ class InstructionBacktranslation(Task):
"""
_template: Optional["Template"] = PrivateAttr(default=...)
+ _can_be_used_with_offline_batch_generation = True
def load(self) -> None:
"""Loads the Jinja2 template."""
diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py
index 3a611f73a2..a137d931dd 100644
--- a/src/distilabel/steps/tasks/magpie/base.py
+++ b/src/distilabel/steps/tasks/magpie/base.py
@@ -13,10 +13,12 @@
# limitations under the License.
import random
+from itertools import zip_longest
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
-from pydantic import Field, PositiveInt
+from pydantic import Field, PositiveInt, field_validator
+from distilabel.errors import DistilabelUserError
from distilabel.llms.base import LLM
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
from distilabel.mixins.runtime_parameters import (
@@ -28,7 +30,7 @@
if TYPE_CHECKING:
from distilabel.steps.tasks.typing import ChatType
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
MAGPIE_MULTI_TURN_SYSTEM_PROMPT = (
"You are a helpful Al assistant. The user will engage in a multi−round conversation"
@@ -45,7 +47,6 @@ class MagpieBase(RuntimeParametersMixin):
- [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464)
Citations:
-
```
@misc{xu2024magpiealignmentdatasynthesis,
title={Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing},
@@ -78,15 +79,44 @@ class MagpieBase(RuntimeParametersMixin):
description="Whether to generate only the instruction. If this argument"
" is `True`, then `n_turns` will be ignored.",
)
- system_prompt: Optional[RuntimeParameter[Union[List[str], str]]] = Field(
+ system_prompt: Optional[
+ RuntimeParameter[
+ Union[List[str], Dict[str, str], Dict[str, Tuple[str, float]], str]
+ ]
+ ] = Field(
default=None,
- description="An optional system prompt or list of system prompts that can be used"
- " to steer the LLM to generate content of certain topic, guide the style, etc.",
+ description="An optional system prompt, or a list of system prompts from which a"
+ " random one will be chosen, or a dictionary of system prompts from which a random"
+ " one will be choosen, or a dictionary of system prompts with their probability of"
+ " being chosen. The random system prompt will be chosen per input/output batch."
+ " This system prompt can be used to guide the generation of the instruct LLM and"
+ " steer it to generate instructions of a certain topic.",
)
+ @field_validator("system_prompt", mode="after")
+ @classmethod
+ def system_prompts_weights_validator(
+ cls,
+ system_prompts: Union[
+ List[str], Dict[str, str], Dict[str, Tuple[str, float]], str
+ ],
+ ) -> Union[List[str], Dict[str, str], Dict[str, Tuple[str, float]], str]:
+ """Validates that the sum of the weights of the system prompts is equal to 1.0."""
+ if isinstance(system_prompts, dict):
+ system_prompts_values = list(system_prompts.values())
+ if isinstance(system_prompts_values[0], tuple):
+ weights_sum = sum(weight for _, weight in system_prompts_values) # type: ignore
+ if weights_sum != 1.0:
+ raise DistilabelUserError(
+ "If `system_prompts` attribute is a dictionary containing tuples with"
+ " the system prompts and their probability of being chosen, then the"
+ " sum of the weights must be equal to 1.0."
+ )
+ return system_prompts
+
def _prepare_inputs_for_instruction_generation(
self, inputs: List[Dict[str, Any]]
- ) -> List["ChatType"]:
+ ) -> Tuple[List["ChatType"], List[str]]:
"""Prepares the inputs adding the system (if required) prompt provided in each row,
or if the conversations to generate have more than one turn, then adding the system
prompt for multi-turn conversation from the paper.
@@ -95,9 +125,10 @@ def _prepare_inputs_for_instruction_generation(
inputs: the inputs to prepare.
Returns:
- The prepared inputs.
+ The prepared inputs and the system prompt keys used for each input.
"""
prepared_inputs = []
+ system_prompt_keys = []
for input in inputs:
conversation = []
if "system_prompt" in input:
@@ -106,7 +137,20 @@ def _prepare_inputs_for_instruction_generation(
)
elif self.system_prompt is not None:
if isinstance(self.system_prompt, list):
- system_prompt = random.choice(self.system_prompt)
+ system_prompt = random.choices(self.system_prompt, k=1)[0]
+ elif isinstance(self.system_prompt, dict):
+ system_prompts_keys = list(self.system_prompt.keys())
+ system_prompts_values = list(self.system_prompt.values())
+ weights: Union[List[float], None] = None
+ if isinstance(system_prompts_values[0], tuple):
+ weights = [weight for _, weight in system_prompts_values] # type: ignore
+ system_prompt_key = random.choices(
+ system_prompts_keys, weights, k=1
+ )[0]
+ system_prompt_keys.append(system_prompt_key)
+ system_prompt = self.system_prompt[system_prompt_key]
+ if isinstance(system_prompt, tuple):
+ system_prompt = system_prompt[0]
else:
system_prompt = self.system_prompt
conversation.append({"role": "system", "content": system_prompt})
@@ -117,7 +161,7 @@ def _prepare_inputs_for_instruction_generation(
prepared_inputs.append(conversation)
- return prepared_inputs
+ return prepared_inputs, system_prompt_keys
def _append_messages_to_conversations(
self, role: str, messages: List[str], conversations: List["ChatType"]
@@ -140,16 +184,26 @@ def _append_messages_to_conversations(
def _generate_instruction(
self, inputs: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
- prepared_inputs = self._prepare_inputs_for_instruction_generation(inputs)
+ prepared_inputs, system_prompt_keys = (
+ self._prepare_inputs_for_instruction_generation(inputs)
+ )
outputs = self.llm.generate(
inputs=prepared_inputs,
num_generations=1,
**self.llm.generation_kwargs, # type: ignore
)
- return [{"instruction": output[0]} for output in outputs]
+ rows = []
+ for output, system_prompt_key in zip_longest(
+ outputs, system_prompt_keys, fillvalue=None
+ ):
+ row = {"instruction": output[0]} # type: ignore
+ if system_prompt_key is not None:
+ row["system_prompt_key"] = system_prompt_key
+ rows.append(row)
+ return rows
def _prepare_conversation_outputs(
- self, conversations: List["ChatType"]
+ self, conversations: List["ChatType"], system_prompt_keys: List[str]
) -> List[Dict[str, Any]]:
"""Prepare the output conversation removing the system prompt if necessary. If
`n_turns==1`, then it will return a dictionary with "instruction" and "response"
@@ -157,24 +211,36 @@ def _prepare_conversation_outputs(
Args:
conversations: the list of generated conversations.
+ system_prompt_keys: the list of system prompt keys used to generate the conversations.
Returns:
A list of dictionaries containing a "conversation" key or "instruction" and
"responses" key.
"""
outputs = []
- for conversation in conversations:
+ for conversation, system_prompt_key in zip_longest(
+ conversations, system_prompt_keys, fillvalue=None
+ ):
+ assert conversation is not None
+ # Something went wrong with the `LLM` and it didn't generate any message
+ if len(conversation) == 0:
+ if self.n_turns == 1:
+ outputs.append({"instruction": None, "response": None})
+ else:
+ outputs.append({"conversation": []})
+ continue
if not self.include_system_prompt and conversation[0]["role"] == "system":
conversation.pop(0)
if self.n_turns == 1 and len(conversation) == 2:
- outputs.append(
- {
- "instruction": conversation[0]["content"],
- "response": conversation[1]["content"],
- }
- )
+ output: Dict[str, Any] = {
+ "instruction": conversation[0]["content"],
+ "response": conversation[1]["content"],
+ }
else:
- outputs.append({"conversation": conversation})
+ output = {"conversation": conversation}
+ if system_prompt_key is not None:
+ output["system_prompt_key"] = system_prompt_key
+ outputs.append(output)
return outputs
def _generate_conversation_turn(
@@ -206,7 +272,7 @@ def _generate_conversation_turn(
def _generate_multi_turn_conversation(
self, inputs: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
- conversations: List["ChatType"] = (
+ conversations, system_prompt_keys = (
self._prepare_inputs_for_instruction_generation(inputs)
)
# Keep track of the active conversations, as it could happen that for some conversation
@@ -235,7 +301,7 @@ def _generate_multi_turn_conversation(
active_indices=active_indices,
)
- return self._prepare_conversation_outputs(conversations)
+ return self._prepare_conversation_outputs(conversations, system_prompt_keys)
def _generate_with_pre_query_template(
self, inputs: List[Dict[str, Any]]
@@ -284,12 +350,12 @@ class Magpie(Task, MagpieBase):
conversation. Defaults to `False`.
only_instruction: whether to generate only the instruction. If this argument is
`True`, then `n_turns` will be ignored. Defaults to `False`.
- system_prompt: an optional system prompt or list of system prompts that can
- be used to steer the LLM to generate content of certain topic, guide the style,
- etc. If it's a list of system prompts, then a random system prompt will be chosen
- per input/output batch. If the provided inputs contains a `system_prompt` column,
- then this runtime parameter will be ignored and the one from the column will
- be used. Defaults to `None`.
+ system_prompt: an optional system prompt, or a list of system prompts from which
+ a random one will be chosen, or a dictionary of system prompts from which a
+ random one will be choosen, or a dictionary of system prompts with their probability
+ of being chosen. The random system prompt will be chosen per input/output batch.
+ This system prompt can be used to guide the generation of the instruct LLM and
+ steer it to generate instructions of a certain topic. Defaults to `None`.
Runtime parameters:
- `n_turns`: the number of turns that the generated conversation will have. Defaults
@@ -306,6 +372,12 @@ class Magpie(Task, MagpieBase):
per input/output batch. If the provided inputs contains a `system_prompt` column,
then this runtime parameter will be ignored and the one from the column will
be used. Defaults to `None`.
+ - `system_prompt`: an optional system prompt, or a list of system prompts from which
+ a random one will be chosen, or a dictionary of system prompts from which a
+ random one will be choosen, or a dictionary of system prompts with their probability
+ of being chosen. The random system prompt will be chosen per input/output batch.
+ This system prompt can be used to guide the generation of the instruct LLM and
+ steer it to generate instructions of a certain topic.
Input columns:
- system_prompt (`str`, optional): an optional system prompt that can be provided
@@ -317,6 +389,8 @@ class Magpie(Task, MagpieBase):
items with a role and a message. Only if `only_instruction=False`.
- instruction (`str`): the generated instructions if `only_instruction=True` or `n_turns==1`.
- response (`str`): the generated response if `n_turns==1`.
+ - system_prompt_key (`str`, optional): the key of the system prompt used to generate
+ the conversation or instruction. Only if `system_prompt` is a dictionary.
- model_name (`str`): The model name used to generate the `conversation` or `instruction`.
Categories:
@@ -327,7 +401,6 @@ class Magpie(Task, MagpieBase):
- [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464)
Examples:
-
Generating instructions with Llama 3 8B Instruct and TransformersLLM:
```python
@@ -440,29 +513,40 @@ def model_post_init(self, __context: Any) -> None:
super().model_post_init(__context)
if not isinstance(self.llm, MagpieChatTemplateMixin):
- raise ValueError(
+ raise DistilabelUserError(
f"`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`."
- f"`{self.llm.__class__.__name__}` doesn't use the aforementioned mixin."
+ f"`{self.llm.__class__.__name__}` doesn't use the aforementioned mixin.",
+ page="components-gallery/tasks/magpie/",
)
self.llm.use_magpie_template = True
@property
- def inputs(self) -> List[str]:
- return []
+ def inputs(self) -> "StepColumns":
+ return {"system_prompt": False}
def format_input(self, input: Dict[str, Any]) -> "ChatType":
"""Does nothing."""
return []
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""Either a multi-turn conversation or the instruction generated."""
+ outputs = []
+
if self.only_instruction:
- return ["instruction", "model_name"]
- if self.n_turns == 1:
- return ["instruction", "response", "model_name"]
- return ["conversation", "model_name"]
+ outputs.append("instruction")
+ elif self.n_turns == 1:
+ outputs.extend(["instruction", "response"])
+ else:
+ outputs.append("conversation")
+
+ if isinstance(self.system_prompt, dict):
+ outputs.append("system_prompt_key")
+
+ outputs.append("model_name")
+
+ return outputs
def format_output(
self,
diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py
index 33a910ef34..c1e413d32c 100644
--- a/src/distilabel/steps/tasks/magpie/generator.py
+++ b/src/distilabel/steps/tasks/magpie/generator.py
@@ -12,17 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, Union
from pydantic import Field
+from typing_extensions import override
+from distilabel.errors import DistilabelUserError
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.base import GeneratorTask
from distilabel.steps.tasks.magpie.base import MagpieBase
if TYPE_CHECKING:
- from distilabel.steps.typing import GeneratorStepOutput
+ from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.typing import GeneratorStepOutput, StepColumns
class MagpieGenerator(GeneratorTask, MagpieBase):
@@ -49,12 +52,12 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
conversation. Defaults to `False`.
only_instruction: whether to generate only the instruction. If this argument is
`True`, then `n_turns` will be ignored. Defaults to `False`.
- system_prompt: an optional system prompt or list of system prompts that can
- be used to steer the LLM to generate content of certain topic, guide the style,
- etc. If it's a list of system prompts, then a random system prompt will be chosen
- per input/output batch. If the provided inputs contains a `system_prompt` column,
- then this runtime parameter will be ignored and the one from the column will
- be used. Defaults to `None`.
+ system_prompt: an optional system prompt, or a list of system prompts from which
+ a random one will be chosen, or a dictionary of system prompts from which a
+ random one will be choosen, or a dictionary of system prompts with their probability
+ of being chosen. The random system prompt will be chosen per input/output batch.
+ This system prompt can be used to guide the generation of the instruct LLM and
+ steer it to generate instructions of a certain topic. Defaults to `None`.
num_rows: the number of rows to be generated.
Runtime parameters:
@@ -66,12 +69,12 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
conversation. Defaults to `False`.
- `only_instruction`: whether to generate only the instruction. If this argument is
`True`, then `n_turns` will be ignored. Defaults to `False`.
- - `system_prompt`: an optional system prompt or list of system prompts that can
- be used to steer the LLM to generate content of certain topic, guide the style,
- etc. If it's a list of system prompts, then a random system prompt will be chosen
- per input/output batch. If the provided inputs contains a `system_prompt` column,
- then this runtime parameter will be ignored and the one from the column will
- be used. Defaults to `None`.
+ - `system_prompt`: an optional system prompt, or a list of system prompts from which
+ a random one will be chosen, or a dictionary of system prompts from which a
+ random one will be choosen, or a dictionary of system prompts with their probability
+ of being chosen. The random system prompt will be chosen per input/output batch.
+ This system prompt can be used to guide the generation of the instruct LLM and
+ steer it to generate instructions of a certain topic.
- `num_rows`: the number of rows to be generated.
Output columns:
@@ -79,6 +82,8 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
items with a role and a message.
- instruction (`str`): the generated instructions if `only_instruction=True`.
- response (`str`): the generated response if `n_turns==1`.
+ - system_prompt_key (`str`, optional): the key of the system prompt used to generate
+ the conversation or instruction. Only if `system_prompt` is a dictionary.
- model_name (`str`): The model name used to generate the `conversation` or `instruction`.
Categories:
@@ -90,7 +95,6 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
- [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464)
Examples:
-
Generating instructions with Llama 3 8B Instruct and TransformersLLM:
```python
@@ -203,8 +207,35 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
# )
```
- Citations:
+ Generating with system prompts with probabilities:
+
+ ```python
+ from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.steps.tasks import MagpieGenerator
+ magpie = MagpieGenerator(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3-8B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
+ magpie_pre_query_template="llama3",
+ generation_kwargs={
+ "temperature": 0.8,
+ "max_new_tokens": 256,
+ },
+ ),
+ n_turns=2,
+ system_prompt={
+ "math": ("You're an expert AI assistant.", 0.8),
+ "writing": ("You're an expert writing assistant.", 0.2),
+ },
+ )
+
+ magpie.load()
+
+ result = next(magpie.process())
+ ```
+
+ Citations:
```
@misc{xu2024magpiealignmentdatasynthesis,
title={Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing},
@@ -228,13 +259,33 @@ def model_post_init(self, __context: Any) -> None:
super().model_post_init(__context)
if not isinstance(self.llm, MagpieChatTemplateMixin):
- raise ValueError(
+ raise DistilabelUserError(
f"`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`."
- f"`{self.llm.__class__.__name__}` doesn't use the aforementioned mixin."
+ f"`{self.llm.__class__.__name__}` doesn't use the aforementioned mixin.",
+ page="components-gallery/tasks/magpiegenerator/",
)
self.llm.use_magpie_template = True
+ @property
+ def outputs(self) -> "StepColumns":
+ """Either a multi-turn conversation or the instruction generated."""
+ outputs = []
+
+ if self.only_instruction:
+ outputs.append("instruction")
+ elif self.n_turns == 1:
+ outputs.extend(["instruction", "response"])
+ else:
+ outputs.append("conversation")
+
+ if isinstance(self.system_prompt, dict):
+ outputs.append("system_prompt_key")
+
+ outputs.append("model_name")
+
+ return outputs
+
def format_output(
self,
output: Union[str, None],
@@ -243,15 +294,6 @@ def format_output(
"""Does nothing."""
return {}
- @property
- def outputs(self) -> List[str]:
- """Either a multi-turn conversation or the instruction generated."""
- if self.only_instruction:
- return ["instruction", "model_name"]
- if self.n_turns == 1:
- return ["instruction", "response", "model_name"]
- return ["conversation", "model_name"]
-
def process(self, offset: int = 0) -> "GeneratorStepOutput":
"""Generates the desired number of instructions or conversations using Magpie.
@@ -272,3 +314,7 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput":
)
generated += rows_to_generate # type: ignore
yield (conversations, generated == self.num_rows)
+
+ @override
+ def _sample_input(self) -> "ChatType":
+ return self._generate_with_pre_query_template(inputs=[{}])
diff --git a/src/distilabel/steps/tasks/pair_rm.py b/src/distilabel/steps/tasks/pair_rm.py
index d4e20c22bf..23262a533f 100644
--- a/src/distilabel/steps/tasks/pair_rm.py
+++ b/src/distilabel/steps/tasks/pair_rm.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
import numpy as np
@@ -20,7 +20,7 @@
from distilabel.steps.tasks.base import Step
if TYPE_CHECKING:
- from distilabel.steps.typing import StepOutput
+ from distilabel.steps.typing import StepColumns, StepOutput
class PairRM(Step):
@@ -51,7 +51,6 @@ class PairRM(Step):
currently, and we will use a specific `LLM`.
Examples:
-
Rank LLM candidates:
```python
@@ -82,7 +81,6 @@ class PairRM(Step):
```
Citations:
-
```
@misc{jiang2023llmblenderensemblinglargelanguage,
title={LLM-Blender: Ensembling Large Language Models with Pairwise Ranking and Generative Fusion},
@@ -114,13 +112,13 @@ def load(self) -> None:
self._blender.loadranker(self.model)
@property
- def inputs(self) -> List[str]:
+ def inputs(self) -> "StepColumns":
"""The input columns correspond to the two required arguments from `Blender.rank`:
`inputs` and `candidates`."""
return ["input", "candidates"]
@property
- def outputs(self) -> List[str]:
+ def outputs(self) -> "StepColumns":
"""The outputs will include the `ranks` and the `ranked_candidates`."""
return ["ranks", "ranked_candidates", "model_name"]
diff --git a/src/distilabel/steps/tasks/prometheus_eval.py b/src/distilabel/steps/tasks/prometheus_eval.py
index 8e1ecbaca9..27cd9622ea 100644
--- a/src/distilabel/steps/tasks/prometheus_eval.py
+++ b/src/distilabel/steps/tasks/prometheus_eval.py
@@ -26,6 +26,7 @@
from pydantic import Field, PrivateAttr, model_validator
from typing_extensions import Self
+from distilabel.errors import DistilabelUserError
from distilabel.steps.tasks.base import Task
if TYPE_CHECKING:
@@ -75,15 +76,11 @@ class PrometheusEval(Task):
"""Critique and rank the quality of generations from an `LLM` using Prometheus 2.0.
`PrometheusEval` is a task created for Prometheus 2.0, covering both the absolute and relative
- evaluations.
-
- - The absolute evaluation i.e. `mode="absolute"` is used to evaluate a single generation from
- an LLM for a given instruction.
- - The relative evaluation i.e. `mode="relative"` is used to evaluate two generations from an LLM
- for a given instruction.
-
- Both evaluations provide the possibility whether to use a reference answer to compare with or not
- via the `reference` attribute, and both are based on a score rubric that critiques the generation/s
+ evaluations. The absolute evaluation i.e. `mode="absolute"` is used to evaluate a single generation from
+ an LLM for a given instruction. The relative evaluation i.e. `mode="relative"` is used to evaluate two generations from an LLM
+ for a given instruction.
+ Both evaluations provide the possibility of using a reference answer to compare with or withoug
+ the `reference` attribute, and both are based on a score rubric that critiques the generation/s
based on the following default aspects: `helpfulness`, `harmlessness`, `honesty`, `factual-validity`,
and `reasoning`, that can be overridden via `rubrics`, and the selected rubric is set via the attribute
`rubric`.
@@ -137,8 +134,7 @@ class PrometheusEval(Task):
- [prometheus-eval: Evaluate your LLM's response with Prometheus 💯](https://github.com/prometheus-eval/prometheus-eval)
Examples:
-
- Critique and evaluate LLM generation quality using Prometheus 2.0:
+ Critique and evaluate LLM generation quality using Prometheus 2_0:
```python
from distilabel.steps.tasks import PrometheusEval
@@ -148,7 +144,7 @@ class PrometheusEval(Task):
prometheus = PrometheusEval(
llm=vLLM(
model="prometheus-eval/prometheus-7b-v2.0",
- chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]",
+ chat_template="[INST] {{ messages[0]\"content\" }}\\n{{ messages[1]\"content\" }}[/INST]",
),
mode="absolute",
rubric="factual-validity"
@@ -185,7 +181,7 @@ class PrometheusEval(Task):
prometheus = PrometheusEval(
llm=vLLM(
model="prometheus-eval/prometheus-7b-v2.0",
- chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]",
+ chat_template="[INST] {{ messages[0]\"content\" }}\\n{{ messages[1]\"content\" }}[/INST]",
),
mode="relative",
rubric="honesty"
@@ -222,12 +218,12 @@ class PrometheusEval(Task):
prometheus = PrometheusEval(
llm=vLLM(
model="prometheus-eval/prometheus-7b-v2.0",
- chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]",
+ chat_template="[INST] {{ messages[0]\"content\" }}\\n{{ messages[1]\"content\" }}[/INST]",
),
mode="absolute",
rubric="custom",
rubrics={
- "custom": "[A]\nScore 1: A\nScore 2: B\nScore 3: C\nScore 4: D\nScore 5: E"
+ "custom": "[A]\\nScore 1: A\\nScore 2: B\\nScore 3: C\\nScore 4: D\\nScore 5: E"
}
)
@@ -262,7 +258,7 @@ class PrometheusEval(Task):
prometheus = PrometheusEval(
llm=vLLM(
model="prometheus-eval/prometheus-7b-v2.0",
- chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]",
+ chat_template="[INST] {{ messages[0]\"content\" }}\\n{{ messages[1]\"content\" }}[/INST]",
),
mode="absolute",
rubric="helpfulness",
@@ -296,7 +292,6 @@ class PrometheusEval(Task):
```
Citations:
-
```
@misc{kim2024prometheus2opensource,
title={Prometheus 2: An Open Source Language Model Specialized in Evaluating Other Language Models},
@@ -320,8 +315,9 @@ class PrometheusEval(Task):
@model_validator(mode="after")
def validate_rubric_and_rubrics(self) -> Self:
if not isinstance(self.rubrics, dict) or len(self.rubrics) < 1:
- raise ValueError(
- "Provided `rubrics` must be a Python dictionary with string keys and string values."
+ raise DistilabelUserError(
+ "Provided `rubrics` must be a Python dictionary with string keys and string values.",
+ page="components-gallery/tasks/prometheuseval/",
)
def rubric_matches_pattern(rubric: str) -> bool:
@@ -330,17 +326,19 @@ def rubric_matches_pattern(rubric: str) -> bool:
return bool(re.match(pattern, rubric, re.MULTILINE))
if not all(rubric_matches_pattern(value) for value in self.rubrics.values()):
- raise ValueError(
+ raise DistilabelUserError(
"Provided rubrics should match the format of the default rubrics, which"
" is as follows: `[]\nScore 1: \nScore 2: \n"
"Score 3: \nScore 4: \nScore 5: `; replacing"
" `` and `` with the actual criteria and description"
- " for each or the scores, respectively."
+ " for each or the scores, respectively.",
+ page="components-gallery/tasks/prometheuseval/",
)
if self.rubric not in self.rubrics:
- raise ValueError(
- f"Provided rubric '{self.rubric}' is not among the available rubrics: {', '.join(self.rubrics.keys())}."
+ raise DistilabelUserError(
+ f"Provided rubric '{self.rubric}' is not among the available rubrics: {', '.join(self.rubrics.keys())}.",
+ page="components-gallery/tasks/prometheuseval/",
)
return self
@@ -393,9 +391,10 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
if self.mode == "absolute":
if not isinstance(input["generation"], str):
- raise ValueError(
+ raise DistilabelUserError(
f"Provided `generation` is of type {type(input['generation'])} but a string"
" should be provided instead.",
+ page="components-gallery/tasks/prometheuseval/",
)
template_kwargs["generation"] = input["generation"]
@@ -412,8 +411,9 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
)
or len(input["generations"]) != 2
):
- raise ValueError(
- f"Provided `generations` is of type {type(input['generations'])} but a list of strings with length 2 should be provided instead."
+ raise DistilabelUserError(
+ f"Provided `generations` is of type {type(input['generations'])} but a list of strings with length 2 should be provided instead.",
+ page="components-gallery/tasks/prometheuseval/",
)
template_kwargs["generations"] = input["generations"]
diff --git a/src/distilabel/steps/tasks/quality_scorer.py b/src/distilabel/steps/tasks/quality_scorer.py
index 5b905fd097..604f2a0276 100644
--- a/src/distilabel/steps/tasks/quality_scorer.py
+++ b/src/distilabel/steps/tasks/quality_scorer.py
@@ -22,8 +22,10 @@
from typing import Any, Dict, List, Union
+import orjson
from jinja2 import Template
from pydantic import PrivateAttr
+from typing_extensions import override
from distilabel.steps.tasks.base import Task
from distilabel.steps.tasks.typing import ChatType
@@ -61,7 +63,6 @@ class QualityScorer(Task):
- [`What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685)
Examples:
-
Evaluate the quality of your instructions:
```python
@@ -97,8 +98,41 @@ class QualityScorer(Task):
]
```
- Citations:
+ Generate structured output with default schema:
+ ```python
+ from distilabel.steps.tasks import QualityScorer
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ scorer = QualityScorer(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+ use_default_structured_output=True
+ )
+
+ scorer.load()
+
+ result = next(
+ scorer.process(
+ [
+ {
+ "instruction": "instruction",
+ "responses": ["good response", "weird response", "bad response"]
+ }
+ ]
+ )
+ )
+
+ # result
+ [{'instruction': 'instruction',
+ 'responses': ['good response', 'weird response', 'bad response'],
+ 'scores': [1, 2, 3],
+ 'distilabel_metadata': {'raw_output_quality_scorer_0': '{ "scores": [1, 2, 3] }'},
+ 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+
+ Citations:
```
@misc{liu2024makesgooddataalignment,
title={What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning},
@@ -113,6 +147,7 @@ class QualityScorer(Task):
"""
_template: Union[Template, None] = PrivateAttr(...)
+ _can_be_used_with_offline_batch_generation = True
def load(self) -> None:
"""Loads the Jinja2 template."""
@@ -166,6 +201,9 @@ def format_output(
if output is None:
return {"scores": [None] * len(input["responses"])}
+ if self.use_default_structured_output:
+ return self._format_structured_output(output, input)
+
scores = []
score_lines = output.split("\n")
@@ -176,3 +214,62 @@ def format_output(
if i == len(input["responses"]) - 1:
break
return {"scores": scores}
+
+ @override
+ def get_structured_output(self) -> Dict[str, Any]:
+ """Creates the json schema to be passed to the LLM, to enforce generating
+ a dictionary with the output which can be directly parsed as a python dictionary.
+
+ The schema corresponds to the following:
+
+ ```python
+ from pydantic import BaseModel
+ from typing import List
+
+ class SchemaQualityScorer(BaseModel):
+ scores: List[int]
+ ```
+
+ Returns:
+ JSON Schema of the response to enforce.
+ """
+ return {
+ "properties": {
+ "scores": {
+ "items": {"type": "integer"},
+ "title": "Scores",
+ "type": "array",
+ }
+ },
+ "required": ["scores"],
+ "title": "SchemaQualityScorer",
+ "type": "object",
+ }
+
+ def _format_structured_output(
+ self, output: str, input: Dict[str, Any]
+ ) -> Dict[str, str]:
+ """Parses the structured response, which should correspond to a dictionary
+ with the scores, and a list with them.
+
+ Args:
+ output: The output from the `LLM`.
+
+ Returns:
+ Formatted output.
+ """
+ try:
+ return orjson.loads(output)
+ except orjson.JSONDecodeError:
+ return {"scores": [None] * len(input["responses"])}
+
+ @override
+ def _sample_input(self) -> ChatType:
+ return self.format_input(
+ {
+ "instruction": f"",
+ "responses": [
+ f"" for i in range(2)
+ ],
+ }
+ )
diff --git a/src/distilabel/steps/tasks/self_instruct.py b/src/distilabel/steps/tasks/self_instruct.py
index 27bcecc0bf..28ac346c39 100644
--- a/src/distilabel/steps/tasks/self_instruct.py
+++ b/src/distilabel/steps/tasks/self_instruct.py
@@ -62,7 +62,6 @@ class SelfInstruct(Task):
- [`Self-Instruct: Aligning Language Models with Self-Generated Instructions`](https://arxiv.org/abs/2212.10560)
Examples:
-
Generate instructions based on a given input:
```python
@@ -90,7 +89,6 @@ class SelfInstruct(Task):
```
Citations:
-
```
@misc{wang2023selfinstructaligninglanguagemodels,
title={Self-Instruct: Aligning Language Models with Self-Generated Instructions},
@@ -114,6 +112,7 @@ class SelfInstruct(Task):
application_description: str = "AI assistant"
_template: Union[Template, None] = PrivateAttr(...)
+ _can_be_used_with_offline_batch_generation = True
def load(self) -> None:
"""Loads the Jinja2 template."""
diff --git a/src/distilabel/steps/tasks/sentence_transformers.py b/src/distilabel/steps/tasks/sentence_transformers.py
index 666c89243f..f33a223c63 100644
--- a/src/distilabel/steps/tasks/sentence_transformers.py
+++ b/src/distilabel/steps/tasks/sentence_transformers.py
@@ -16,7 +16,9 @@
import sys
from typing import TYPE_CHECKING, Any, Dict, Final, List, Literal, Optional, Union
+import orjson
from jinja2 import Template
+from typing_extensions import override
from distilabel.steps.tasks.base import Task
@@ -31,7 +33,7 @@
GenerationAction = Literal["paraphrase", "semantically-similar", "query", "answer"]
POSITIVE_NEGATIVE_PAIR_REGEX = re.compile(
- r"## Positive\s+(.*?)(?:\s+## Negative\s+(.*?))?\s*$",
+ r"\s*## Positive\s+(.*?)(?:\s+## Negative\s+(.*?))?\s*$",
re.DOTALL,
)
@@ -102,7 +104,6 @@ class GenerateSentencePair(Task):
- embedding
Examples:
-
Paraphrasing:
```python
@@ -232,6 +233,28 @@ class GenerateSentencePair(Task):
result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
```
+ Generating structured data with default schema (**applies to every action**):
+
+ ```python
+ from distilabel.steps.tasks import GenerateSentencePair
+ from distilabel.llms import InferenceEndpointsLLM
+
+ generate_sentence_pair = GenerateSentencePair(
+ triplet=True, # `False` to generate only positive
+ action="query",
+ context="Argilla is an open-source data curation platform for LLMs.",
+ hard_negative=True,
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+ input_batch_size=10,
+ use_default_structured_output=True
+ )
+
+ generate_sentence_pair.load()
+
+ result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
+ ```
"""
triplet: bool = False
@@ -320,6 +343,9 @@ def format_output(
if output is None:
return {"positive": None, "negative": None}
+ if self.use_default_structured_output:
+ return self._format_structured_output(output)
+
match = POSITIVE_NEGATIVE_PAIR_REGEX.match(output)
if match is None:
formatted_output = {"positive": None}
@@ -331,9 +357,53 @@ def format_output(
if self.triplet:
return {
"positive": groups[0].strip(),
- "negative": groups[1].strip()
- if len(groups) > 1 and groups[1] is not None
- else None,
+ "negative": (
+ groups[1].strip()
+ if len(groups) > 1 and groups[1] is not None
+ else None
+ ),
}
return {"positive": groups[0].strip()}
+
+ @override
+ def get_structured_output(self) -> Dict[str, Any]:
+ """Creates the json schema to be passed to the LLM, to enforce generating
+ a dictionary with the output which can be directly parsed as a python dictionary.
+
+ Returns:
+ JSON Schema of the response to enforce.
+ """
+ if self.triplet:
+ return {
+ "properties": {
+ "positive": {"title": "Positive", "type": "string"},
+ "negative": {"title": "Negative", "type": "string"},
+ },
+ "required": ["positive", "negative"],
+ "title": "Schema",
+ "type": "object",
+ }
+ return {
+ "properties": {"positive": {"title": "Positive", "type": "string"}},
+ "required": ["positive"],
+ "title": "Schema",
+ "type": "object",
+ }
+
+ def _format_structured_output(self, output: str) -> Dict[str, str]:
+ """Parses the structured response, which should correspond to a dictionary
+ with either `positive`, or `positive` and `negative` keys.
+
+ Args:
+ output: The output from the `LLM`.
+
+ Returns:
+ Formatted output.
+ """
+ try:
+ return orjson.loads(output)
+ except orjson.JSONDecodeError:
+ if self.triplet:
+ return {"positive": None, "negative": None}
+ return {"positive": None}
diff --git a/src/distilabel/steps/tasks/structured_generation.py b/src/distilabel/steps/tasks/structured_generation.py
index 240cd44698..81ee74bd85 100644
--- a/src/distilabel/steps/tasks/structured_generation.py
+++ b/src/distilabel/steps/tasks/structured_generation.py
@@ -15,6 +15,7 @@
import warnings
from typing import Any, Dict, List, Union
+from distilabel.errors import DistilabelUserError
from distilabel.steps.tasks.base import Task
from distilabel.steps.tasks.typing import StructuredInput
@@ -47,7 +48,6 @@ class StructuredGeneration(Task):
- structured-generation
Examples:
-
Generate structured output from a JSON schema:
```python
@@ -69,8 +69,8 @@ class StructuredGeneration(Task):
{
"instruction": "Create an RPG character",
"structured_output": {
- "type": "json",
- "value": {
+ "format": "json",
+ "schema": {
"properties": {
"name": {
"title": "Name",
@@ -105,7 +105,7 @@ class StructuredGeneration(Task):
)
```
- Generate structured output from a regex pattern:
+ Generate structured output from a regex pattern (only works with LLMs that support regex, the providers using outlines):
```python
from distilabel.steps.tasks import StructuredGeneration
@@ -126,8 +126,8 @@ class StructuredGeneration(Task):
{
"instruction": "What's the weather like today in Seattle in Celsius degrees?",
"structured_output": {
- "type": "regex",
- "value": r"(\\d{1,2})°C"
+ "format": "regex",
+ "schema": r"(\\d{1,2})°C"
},
}
@@ -153,8 +153,9 @@ def format_input(self, input: Dict[str, Any]) -> StructuredInput:
"""The input is formatted as a `ChatType` assuming that the instruction
is the first interaction from the user within a conversation."""
if not isinstance(input["instruction"], str):
- raise ValueError(
- f"Input `instruction` must be a string. Got: {input['instruction']}."
+ raise DistilabelUserError(
+ f"Input `instruction` must be a string. Got: {input['instruction']}.",
+ page="components-gallery/tasks/structuredgeneration/",
)
messages = [{"role": "user", "content": input["instruction"]}]
diff --git a/src/distilabel/steps/tasks/structured_outputs/instructor.py b/src/distilabel/steps/tasks/structured_outputs/instructor.py
index 94ab1097e5..93b90d9916 100644
--- a/src/distilabel/steps/tasks/structured_outputs/instructor.py
+++ b/src/distilabel/steps/tasks/structured_outputs/instructor.py
@@ -24,6 +24,8 @@
get_args,
)
+from distilabel.errors import DistilabelUserError
+
if TYPE_CHECKING:
import instructor
from anthropic import AsyncAnthropic
@@ -115,8 +117,9 @@ def prepare_instructor(
mode = mode or default_mode
if mode.value not in [m.value for m in instructor.mode.Mode]:
- raise ValueError(
- f"Invalid mode '{mode}'. Must be one of {[m.value for m in instructor.mode.Mode]}"
+ raise DistilabelUserError(
+ f"Invalid mode '{mode}'. Must be one of {[m.value for m in instructor.mode.Mode]}",
+ page="sections/how_to_guides/advanced/structured_generation/#instructor",
)
patched_client: instructor.AsyncInstructor = builder(client, mode=mode)
diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py
index d726b5e4f5..62419d37b0 100644
--- a/src/distilabel/steps/tasks/structured_outputs/outlines.py
+++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py
@@ -14,6 +14,7 @@
import importlib
import importlib.util
+import inspect
import json
from typing import (
Any,
@@ -28,6 +29,7 @@
from pydantic import BaseModel
+from distilabel.errors import DistilabelUserError
from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict
from distilabel.steps.tasks.typing import StructuredOutputType
@@ -63,8 +65,9 @@ def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
return JSONLogitsProcessor, RegexLogitsProcessor
- raise ValueError(
- f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}"
+ raise DistilabelUserError(
+ f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}",
+ page="sections/how_to_guides/advanced/structured_generation/",
)
@@ -102,6 +105,13 @@ def prepare_guided_output(
format = structured_output.get("format")
schema = structured_output.get("schema")
+ # If schema not informed (may be forgotten), try infering it
+ if not format:
+ if isinstance(schema, dict) or inspect.isclass(schema):
+ format = "json"
+ elif isinstance(schema, str):
+ format = "regex"
+
if format == "json":
return {
"processor": json_processor(
@@ -115,4 +125,7 @@ def prepare_guided_output(
if format == "regex":
return {"processor": regex_processor(schema, llm)}
- raise ValueError(f"Invalid format '{format}'. Must be either 'json' or 'regex'.")
+ raise DistilabelUserError(
+ f"Invalid format '{format}'. Must be either 'json' or 'regex'.",
+ page="sections/how_to_guides/advanced/structured_generation/",
+ )
diff --git a/src/distilabel/steps/tasks/templates/apigen/generator.jinja2 b/src/distilabel/steps/tasks/templates/apigen/generator.jinja2
new file mode 100644
index 0000000000..cc92c725c3
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/apigen/generator.jinja2
@@ -0,0 +1,10 @@
+Here are examples of queries and the corresponding answers for similar functions:
+{{ examples }}
+
+Note that the query could be interpreted as a combination of several independent requests.
+{{ parallel_queries }}
+Based on these examples, generate {{ number }} diverse query and answer pairs for the function `{{ func_name }}`.
+The detailed function description is the following:
+{{ func_desc }}
+{{ format_inst }}
+Now please generate {{ number }} diverse query and answer pairs following the above format.
\ No newline at end of file
diff --git a/src/distilabel/steps/tasks/templates/apigen/semantic_checker.jinja2 b/src/distilabel/steps/tasks/templates/apigen/semantic_checker.jinja2
new file mode 100644
index 0000000000..8d94357e7e
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/apigen/semantic_checker.jinja2
@@ -0,0 +1,13 @@
+Given Information:
+- All Available Functions:
+{{ func_desc }}
+- User Query: {{ query }}
+- Generated Function Calls: {{ func_call }}
+- Execution Results: {{ execution_result }}
+
+Note: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.
+
+The main decision factor is wheather the function calls accurately reflect the query's intentions and the function descriptions.
+Provide your reasoning in the thought section and decide if the data passes (answer yes or no).
+If not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.
+{{ format_inst }}
\ No newline at end of file
diff --git a/src/distilabel/steps/tasks/templates/argillalabeller.jinja2 b/src/distilabel/steps/tasks/templates/argillalabeller.jinja2
new file mode 100644
index 0000000000..d5afa75d27
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/argillalabeller.jinja2
@@ -0,0 +1,13 @@
+Please provide an answer to the question based on the input fields{% if examples %} and examples{% endif %}.
+{% if guidelines %}
+# Guidelines
+{{ guidelines }}
+{% endif %}
+# Input Fields
+{{ fields }}
+# Question
+{{ question }}
+{% if examples %}
+# Examples
+{{ examples }}
+{% endif %}
\ No newline at end of file
diff --git a/src/distilabel/steps/tasks/templates/clair.jinja2 b/src/distilabel/steps/tasks/templates/clair.jinja2
new file mode 100644
index 0000000000..3815c6db83
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/clair.jinja2
@@ -0,0 +1,7 @@
+{task}: {{ task }}
+
+{student_solution}: {{ student_solution }}
+
+-----------------
+
+Let's first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer.
\ No newline at end of file
diff --git a/src/distilabel/steps/tasks/templates/urial.jinja2 b/src/distilabel/steps/tasks/templates/urial.jinja2
new file mode 100644
index 0000000000..09a45bcc58
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/urial.jinja2
@@ -0,0 +1,16 @@
+# Instruction
+
+Below is a list of conversations between a human and an AI assistant (you).
+Users place their queries under "# User:", and your responses are under "# Assistant:".
+You are a helpful, respectful, and honest assistant.
+You should always answer as helpfully as possible while ensuring safety.
+Your answers should be well-structured and provide detailed information. They should also have an engaging tone.
+Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.
+Your response must be socially responsible, and thus you can refuse to answer some controversial topics.
+
+{% for message in messages %}
+# {{ message.role | capitalize }}:
+
+{{ message.content }}
+{% endfor %}
+# Assistant:
diff --git a/src/distilabel/steps/tasks/text_classification.py b/src/distilabel/steps/tasks/text_classification.py
new file mode 100644
index 0000000000..5d04b3b2db
--- /dev/null
+++ b/src/distilabel/steps/tasks/text_classification.py
@@ -0,0 +1,378 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from textwrap import indent
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import orjson
+from jinja2 import Template
+from pydantic import BaseModel, Field, PositiveInt, PrivateAttr
+from typing_extensions import override
+
+from distilabel.steps.tasks import Task
+
+if TYPE_CHECKING:
+ from distilabel.steps.tasks.typing import ChatType
+
+
+TEXT_CLASSIFICATION_TEMPLATE: str = """\
+# Instruction
+Please classify the {{ query_title.lower() }} by assigning the most appropriate labels.
+Do not explain your reasoning or provide any additional commentary.
+If the text is ambiguous or lacks sufficient information for classification, respond with "{{ default_label }}".
+{{ labels_message }}{{ context}}
+{{ available_labels }}
+{{ examples }}
+
+## {{ query_title }}
+```
+{{ text }}
+```
+
+## Output Format
+Now, please give me the labels in JSON format, do not include any other text in your response:
+```
+{
+ "labels": {{ labels_format }}
+}
+```
+""".rstrip()
+
+
+class TextClassification(Task):
+ r"""Classifies text into one or more categories or labels.
+
+ This task can be used for text classification problems, where the goal is to assign
+ one or multiple labels to a given text.
+ It uses structured generation as per the reference paper by default,
+ it can help to generate more concise labels. See section 4.1 in the reference.
+
+ Input columns:
+ - text (`str`): The reference text we want to obtain labels for.
+
+ Output columns:
+ - labels (`Union[str, List[str]]`): The label or list of labels for the text.
+ - model_name (`str`): The name of the model used to generate the label/s.
+
+ Categories:
+ - text-classification
+
+ References:
+ - [`Let Me Speak Freely? A Study on the Impact of Format Restrictions on Performance of Large Language Models`](https://arxiv.org/abs/2408.02442)
+
+ Attributes:
+ system_prompt: A prompt to display to the user before the task starts. Contains a default
+ message to make the model behave like a classifier specialist.
+ n: Number of labels to generate If only 1 is required, corresponds to a label
+ classification problem, if >1 it will intend return the "n" labels most representative
+ for the text. Defaults to 1.
+ context: Context to use when generating the labels. By default contains a generic message,
+ but can be used to customize the context for the task.
+ examples: List of examples to help the model understand the task, few shots.
+ available_labels: List of available labels to choose from when classifying the text, or
+ a dictionary with the labels and their descriptions.
+ default_label: Default label to use when the text is ambiguous or lacks sufficient information for
+ classification. Can be a list in case of multiple labels (n>1).
+
+ Examples:
+ Assigning a sentiment to a text:
+
+ ```python
+ from distilabel.steps.tasks import TextClassification
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ llm = InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ )
+
+ text_classification = TextClassification(
+ llm=llm,
+ context="You are an AI system specialized in assigning sentiment to movies.",
+ available_labels=["positive", "negative"],
+ )
+
+ text_classification.load()
+
+ result = next(
+ text_classification.process(
+ [{"text": "This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three."}]
+ )
+ )
+ # result
+ # [{'text': 'This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three.',
+ # 'labels': 'positive',
+ # 'distilabel_metadata': {'raw_output_text_classification_0': '{\n "labels": "positive"\n}',
+ # 'raw_input_text_classification_0': [{'role': 'system',
+ # 'content': 'You are an AI system specialized in generating labels to classify pieces of text. Your sole purpose is to analyze the given text and provide appropriate classification labels.'},
+ # {'role': 'user',
+ # 'content': '# Instruction\nPlease classify the user query by assigning the most appropriate labels.\nDo not explain your reasoning or provide any additional commentary.\nIf the text is ambiguous or lacks sufficient information for classification, respond with "Unclassified".\nProvide the label that best describes the text.\nYou are an AI system specialized in assigning sentiment to movie the user queries.\n## Labeling the user input\nUse the available labels to classify the user query. Analyze the context of each label specifically:\navailable_labels = [\n "positive", # The text shows positive sentiment\n "negative", # The text shows negative sentiment\n]\n\n\n## User Query\n```\nThis was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three.\n```\n\n## Output Format\nNow, please give me the labels in JSON format, do not include any other text in your response:\n```\n{\n "labels": "label"\n}\n```'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+
+ Assigning predefined labels with specified descriptions:
+
+ ```python
+ from distilabel.steps.tasks import TextClassification
+
+ text_classification = TextClassification(
+ llm=llm,
+ n=1,
+ context="Determine the intent of the text.",
+ available_labels={
+ "complaint": "A statement expressing dissatisfaction or annoyance about a product, service, or experience. It's a negative expression of discontent, often with the intention of seeking a resolution or compensation.",
+ "inquiry": "A question or request for information about a product, service, or situation. It's a neutral or curious expression seeking clarification or details.",
+ "feedback": "A statement providing evaluation, opinion, or suggestion about a product, service, or experience. It can be positive, negative, or neutral, and is often intended to help improve or inform.",
+ "praise": "A statement expressing admiration, approval, or appreciation for a product, service, or experience. It's a positive expression of satisfaction or delight, often with the intention of encouraging or recommending."
+ },
+ query_title="Customer Query",
+ )
+
+ text_classification.load()
+
+ result = next(
+ text_classification.process(
+ [{"text": "Can you tell me more about your return policy?"}]
+ )
+ )
+ # result
+ # [{'text': 'Can you tell me more about your return policy?',
+ # 'labels': 'inquiry',
+ # 'distilabel_metadata': {'raw_output_text_classification_0': '{\n "labels": "inquiry"\n}',
+ # 'raw_input_text_classification_0': [{'role': 'system',
+ # 'content': 'You are an AI system specialized in generating labels to classify pieces of text. Your sole purpose is to analyze the given text and provide appropriate classification labels.'},
+ # {'role': 'user',
+ # 'content': '# Instruction\nPlease classify the customer query by assigning the most appropriate labels.\nDo not explain your reasoning or provide any additional commentary.\nIf the text is ambiguous or lacks sufficient information for classification, respond with "Unclassified".\nProvide the label that best describes the text.\nDetermine the intent of the text.\n## Labeling the user input\nUse the available labels to classify the user query. Analyze the context of each label specifically:\navailable_labels = [\n "complaint", # A statement expressing dissatisfaction or annoyance about a product, service, or experience. It\'s a negative expression of discontent, often with the intention of seeking a resolution or compensation.\n "inquiry", # A question or request for information about a product, service, or situation. It\'s a neutral or curious expression seeking clarification or details.\n "feedback", # A statement providing evaluation, opinion, or suggestion about a product, service, or experience. It can be positive, negative, or neutral, and is often intended to help improve or inform.\n "praise", # A statement expressing admiration, approval, or appreciation for a product, service, or experience. It\'s a positive expression of satisfaction or delight, often with the intention of encouraging or recommending.\n]\n\n\n## Customer Query\n```\nCan you tell me more about your return policy?\n```\n\n## Output Format\nNow, please give me the labels in JSON format, do not include any other text in your response:\n```\n{\n "labels": "label"\n}\n```'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+
+ Free multi label classification without predefined labels:
+
+ ```python
+ from distilabel.steps.tasks import TextClassification
+
+ text_classification = TextClassification(
+ llm=llm,
+ n=3,
+ context=(
+ "Describe the main themes, topics, or categories that could describe the "
+ "following type of persona."
+ ),
+ query_title="Example of Persona",
+ )
+
+ text_classification.load()
+
+ result = next(
+ text_classification.process(
+ [{"text": "A historian or curator of Mexican-American history and culture focused on the cultural, social, and historical impact of the Mexican presence in the United States."}]
+ )
+ )
+ # result
+ # [{'text': 'A historian or curator of Mexican-American history and culture focused on the cultural, social, and historical impact of the Mexican presence in the United States.',
+ # 'labels': ['Historical Researcher',
+ # 'Cultural Specialist',
+ # 'Ethnic Studies Expert'],
+ # 'distilabel_metadata': {'raw_output_text_classification_0': '{\n "labels": ["Historical Researcher", "Cultural Specialist", "Ethnic Studies Expert"]\n}',
+ # 'raw_input_text_classification_0': [{'role': 'system',
+ # 'content': 'You are an AI system specialized in generating labels to classify pieces of text. Your sole purpose is to analyze the given text and provide appropriate classification labels.'},
+ # {'role': 'user',
+ # 'content': '# Instruction\nPlease classify the example of persona by assigning the most appropriate labels.\nDo not explain your reasoning or provide any additional commentary.\nIf the text is ambiguous or lacks sufficient information for classification, respond with "Unclassified".\nProvide a list of 3 labels that best describe the text.\nDescribe the main themes, topics, or categories that could describe the following type of persona.\nUse clear, widely understood terms for labels.Avoid overly specific or obscure labels unless the text demands it.\n\n\n## Example of Persona\n```\nA historian or curator of Mexican-American history and culture focused on the cultural, social, and historical impact of the Mexican presence in the United States.\n```\n\n## Output Format\nNow, please give me the labels in JSON format, do not include any other text in your response:\n```\n{\n "labels": ["label_0", "label_1", "label_2"]\n}\n```'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+ """
+
+ system_prompt: Optional[str] = (
+ "You are an AI system specialized in generating labels to classify pieces of text. "
+ "Your sole purpose is to analyze the given text and provide appropriate classification labels."
+ )
+ n: PositiveInt = Field(
+ default=1,
+ description="Number of labels to generate. Defaults to 1.",
+ )
+ context: Optional[str] = Field(
+ default="Generate concise, relevant labels that accurately represent the text's main themes, topics, or categories.",
+ description="Context to use when generating the labels.",
+ )
+ examples: Optional[List[str]] = Field(
+ default=None,
+ description="List of examples to help the model understand the task, few shots.",
+ )
+ available_labels: Optional[Union[List[str], Dict[str, str]]] = Field(
+ default=None,
+ description=(
+ "List of available labels to choose from when classifying the text, or "
+ "a dictionary with the labels and their descriptions."
+ ),
+ )
+ default_label: Optional[Union[str, List[str]]] = Field(
+ default="Unclassified",
+ description=(
+ "Default label to use when the text is ambiguous or lacks sufficient information for "
+ "classification. Can be a list in case of multiple labels (n>1)."
+ ),
+ )
+ query_title: str = Field(
+ default="User Query",
+ description="Title of the query used to show the example/s to classify.",
+ )
+ use_default_structured_output: bool = True
+
+ _template: Optional[Template] = PrivateAttr(default=None)
+
+ def load(self) -> None:
+ super().load()
+ self._template = Template(TEXT_CLASSIFICATION_TEMPLATE)
+ self._labels_format: str = (
+ '"label"'
+ if self.n == 1
+ else "[" + ", ".join([f'"label_{i}"' for i in range(self.n)]) + "]"
+ )
+ self._labels_message: str = (
+ "Provide the label that best describes the text."
+ if self.n == 1
+ else f"Provide a list of {self.n} labels that best describe the text."
+ )
+ self._available_labels_message: str = self._get_available_labels_message()
+ self._examples: str = self._get_examples_message()
+
+ def _get_available_labels_message(self) -> str:
+ """Prepares the message to display depending on the available labels (if any),
+ and whether the labels have a specific context.
+ """
+ if self.available_labels is None:
+ return (
+ "Use clear, widely understood terms for labels."
+ "Avoid overly specific or obscure labels unless the text demands it."
+ )
+
+ msg = (
+ "## Labeling the user input\n"
+ "Use the available labels to classify the user query{label_context}:\n"
+ "available_labels = {available_labels}"
+ )
+ if isinstance(self.available_labels, list):
+ specific_msg = (
+ "[\n"
+ + indent(
+ "".join([f'"{label}",\n' for label in self.available_labels]),
+ prefix=" " * 4,
+ )
+ + "]"
+ )
+ return msg.format(label_context="", available_labels=specific_msg)
+
+ elif isinstance(self.available_labels, dict):
+ specific_msg = ""
+ for label, description in self.available_labels.items():
+ specific_msg += indent(
+ f'"{label}", # {description}' + "\n", prefix=" " * 4
+ )
+
+ specific_msg = "[\n" + specific_msg + "]"
+ return msg.format(
+ label_context=". Analyze the context of each label specifically",
+ available_labels=specific_msg,
+ )
+
+ def _get_examples_message(self) -> str:
+ """Prepares the message to display depending on the examples provided."""
+ if self.examples is None:
+ return ""
+
+ examples_msg = "\n".join([f"- {ex}" for ex in self.examples])
+
+ return (
+ "\n## Examples\n"
+ "Here are some examples to help you understand the task:\n"
+ f"{examples_msg}"
+ )
+
+ @property
+ def inputs(self) -> List[str]:
+ """The input for the task is the `instruction`."""
+ return ["text"]
+
+ @property
+ def outputs(self) -> List[str]:
+ """The output for the task is the `generation` and the `model_name`."""
+ return ["labels", "model_name"]
+
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ """The input is formatted as a `ChatType` assuming that the instruction
+ is the first interaction from the user within a conversation."""
+ messages = [
+ {
+ "role": "user",
+ "content": self._template.render( # type: ignore
+ context=f"\n{self.context}",
+ labels_message=self._labels_message,
+ available_labels=self._available_labels_message,
+ examples=self._examples,
+ default_label=self.default_label,
+ labels_format=self._labels_format,
+ query_title=self.query_title,
+ text=input["text"],
+ ),
+ },
+ ]
+ if self.system_prompt:
+ messages.insert(0, {"role": "system", "content": self.system_prompt})
+ return messages
+
+ def format_output(
+ self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """The output is formatted as a dictionary with the `generation`. The `model_name`
+ will be automatically included within the `process` method of `Task`."""
+ return self._format_structured_output(output)
+
+ @override
+ def get_structured_output(self) -> Dict[str, Any]:
+ """Creates the json schema to be passed to the LLM, to enforce generating
+ a dictionary with the output which can be directly parsed as a python dictionary.
+
+ Returns:
+ JSON Schema of the response to enforce.
+ """
+ if self.n > 1:
+
+ class MultiLabelSchema(BaseModel):
+ labels: List[str]
+
+ return MultiLabelSchema.model_json_schema()
+
+ class SingleLabelSchema(BaseModel):
+ labels: str
+
+ return SingleLabelSchema.model_json_schema()
+
+ def _format_structured_output(
+ self, output: str
+ ) -> Dict[str, Union[str, List[str]]]:
+ """Parses the structured response, which should correspond to a dictionary
+ with the `labels`, and either a string or a list of strings with the labels.
+
+ Args:
+ output: The output from the `LLM`.
+
+ Returns:
+ Formatted output.
+ """
+ try:
+ return orjson.loads(output)
+ except orjson.JSONDecodeError:
+ if self.n > 1:
+ return {"labels": [None for _ in range(self.n)]}
+ return {"labels": None}
diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py
index f5c4659651..a8b2048e54 100644
--- a/src/distilabel/steps/tasks/text_generation.py
+++ b/src/distilabel/steps/tasks/text_generation.py
@@ -12,28 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import warnings
-from typing import Any, Dict, List, Union
+import re
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from jinja2 import Template
+from pydantic import Field, PrivateAttr
+
+from distilabel.errors import DistilabelUserError
from distilabel.steps.tasks.base import Task
-from distilabel.steps.tasks.typing import ChatType
from distilabel.utils.chat import is_openai_format
+if TYPE_CHECKING:
+ from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.typing import StepColumns
+
class TextGeneration(Task):
- """Simple text generation with an `LLM` given an instruction.
+ """Text generation with an `LLM` given a prompt.
- `TextGeneration` is a pre-defined task that defines the `instruction` as the input
- and `generation` as the output. This task is used to generate text based on the input
- instruction. The model_name is also returned as part of the output in order to enhance it.
+ `TextGeneration` is a pre-defined task that allows passing a custom prompt using the
+ Jinja2 syntax. By default, a `instruction` is expected in the inputs, but the using
+ `template` and `columns` attributes one can define a custom prompt and columns expected
+ from the text. This task should be good enough for tasks that don't need post-processing
+ of the responses generated by the LLM.
Attributes:
- use_system_prompt: Whether to use the system prompt in the generation. Defaults to `True`,
- which means that if the column `system_prompt` is defined within the input batch, then
- the `system_prompt` will be used, otherwise, it will be ignored.
+ system_prompt: The system prompt to use in the generation. If not provided, then
+ it will check if the input row has a column named `system_prompt` and use it.
+ If not, then no system prompt will be used. Defaults to `None`.
+ template: The template to use for the generation. It must follow the Jinja2 template
+ syntax. If not provided, it will assume the text passed is an instruction and
+ construct the appropriate template.
+ columns: A string with the column, or a list with columns expected in the template.
+ Take a look at the examples for more information. Defaults to `instruction`.
+ use_system_prompt: DEPRECATED. To be removed in 1.5.0. Whether to use the system
+ prompt in the generation. Defaults to `True`, which means that if the column
+ `system_prompt` is defined within the input batch, then the `system_prompt`
+ will be used, otherwise, it will be ignored.
Input columns:
- - instruction (`str`): The instruction to generate text from.
+ - dynamic (determined by `columns` attribute): By default will be set to `instruction`.
+ The columns can point both to a `str` or a `List[str]` to be used in the template.
Output columns:
- generation (`str`): The generated text.
@@ -42,8 +61,10 @@ class TextGeneration(Task):
Categories:
- text-generation
- Examples:
+ References:
+ - [Jinja2 Template Designer Documentation](https://jinja.palletsprojects.com/en/3.1.x/templates/)
+ Examples:
Generate text from an instruction:
```python
@@ -53,7 +74,7 @@ class TextGeneration(Task):
# Consider this as a placeholder for your actual LLM.
text_gen = TextGeneration(
llm=InferenceEndpointsLLM(
- model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
)
)
@@ -68,47 +89,201 @@ class TextGeneration(Task):
# [
# {
# 'instruction': 'your instruction',
- # 'model_name': 'mistralai/Mistral-7B-Instruct-v0.2',
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct',
# 'generation': 'generation',
# }
# ]
```
- """
- use_system_prompt: bool = True
+ Use a custom template to generate text:
- @property
- def inputs(self) -> List[str]:
- """The input for the task is the `instruction`."""
- return ["instruction"]
+ ```python
+ from distilabel.steps.tasks import TextGeneration
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
- def format_input(self, input: Dict[str, Any]) -> ChatType:
- """The input is formatted as a `ChatType` assuming that the instruction
- is the first interaction from the user within a conversation."""
+ CUSTOM_TEMPLATE = '''Document:
+ {{ document }}
+
+ Question: {{ question }}
+
+ Please provide a clear and concise answer to the question based on the information in the document and your general knowledge:
+ '''.rstrip()
+
+ text_gen = TextGeneration(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+ system_prompt="You are a helpful AI assistant. Your task is to answer the following question based on the provided document. If the answer is not explicitly stated in the document, use your knowledge to provide the most relevant and accurate answer possible. If you cannot answer the question based on the given information, state that clearly.",
+ template=CUSTOM_TEMPLATE,
+ columns=["document", "question"],
+ )
- if is_openai_format(input["instruction"]):
- raise ValueError(
- "Providing `instruction` formatted as an OpenAI chat / conversation is"
- " deprecated, you should use `ChatGeneration` with `messages` as input instead.",
+ text_gen.load()
+
+ result = next(
+ text_gen.process(
+ [
+ {
+ "document": "The Great Barrier Reef, located off the coast of Australia, is the world's largest coral reef system. It stretches over 2,300 kilometers and is home to a diverse array of marine life, including over 1,500 species of fish. However, in recent years, the reef has faced significant challenges due to climate change, with rising sea temperatures causing coral bleaching events.",
+ "question": "What is the main threat to the Great Barrier Reef mentioned in the document?"
+ }
+ ]
)
+ )
+ # result
+ # [
+ # {
+ # 'document': 'The Great Barrier Reef, located off the coast of Australia, is the world's largest coral reef system. It stretches over 2,300 kilometers and is home to a diverse array of marine life, including over 1,500 species of fish. However, in recent years, the reef has faced significant challenges due to climate change, with rising sea temperatures causing coral bleaching events.',
+ # 'question': 'What is the main threat to the Great Barrier Reef mentioned in the document?',
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct',
+ # 'generation': 'According to the document, the main threat to the Great Barrier Reef is climate change, specifically rising sea temperatures causing coral bleaching events.',
+ # }
+ # ]
+ ```
+
+ Few shot learning with different system prompts:
+
+ ```python
+ from distilabel.steps.tasks import TextGeneration
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
- if not isinstance(input["instruction"], str):
- raise ValueError(
- f"Input `instruction` must be a string. Got: {input['instruction']}."
+ CUSTOM_TEMPLATE = '''Generate a clear, single-sentence instruction based on the following examples:
+
+ {% for example in examples %}
+ Example {{ loop.index }}:
+ Instruction: {{ example }}
+
+ {% endfor %}
+ Now, generate a new instruction in a similar style:
+ '''.rstrip()
+
+ text_gen = TextGeneration(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+ template=CUSTOM_TEMPLATE,
+ columns="examples",
+ )
+
+ text_gen.load()
+
+ result = next(
+ text_gen.process(
+ [
+ {
+ "examples": ["This is an example", "Another relevant example"],
+ "system_prompt": "You are an AI assistant specialised in cybersecurity and computing in general, you make your point clear without any explanations."
+ }
+ ]
)
+ )
+ # result
+ # [
+ # {
+ # 'examples': ['This is an example', 'Another relevant example'],
+ # 'system_prompt': 'You are an AI assistant specialised in cybersecurity and computing in general, you make your point clear without any explanations.',
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct',
+ # 'generation': 'Disable the firewall on the router',
+ # }
+ # ]
+ ```
+ """
- messages = [{"role": "user", "content": input["instruction"]}]
- if self.use_system_prompt:
- if "system_prompt" in input:
- messages.insert(
- 0, {"role": "system", "content": input["system_prompt"]}
+ system_prompt: Union[str, None] = None
+ use_system_prompt: bool = Field(default=True, deprecated=True)
+ template: str = Field(
+ default="{{ instruction }}",
+ description=(
+ "This is a template or prompt to use for the generation. "
+ "If not provided, it is assumed a `instruction` is placed in the inputs, "
+ "to be used as is."
+ ),
+ )
+ columns: Union[str, List[str]] = Field(
+ default="instruction",
+ description=(
+ "Custom column or list of columns to include in the input. "
+ "If a `template` is provided which needs custom column names, "
+ "then they should be provided here. By default it will use `instruction`."
+ ),
+ )
+
+ _can_be_used_with_offline_batch_generation = True
+ _template: Optional["Template"] = PrivateAttr(default=...)
+
+ def model_post_init(self, __context: Any) -> None:
+ self.columns = [self.columns] if isinstance(self.columns, str) else self.columns
+ super().model_post_init(__context)
+
+ def load(self) -> None:
+ super().load()
+
+ def check_column_in_template(column, template):
+ pattern = (
+ r"(?:{%.*?\b"
+ + re.escape(column)
+ + r"\b.*?%}|{{\s*"
+ + re.escape(column)
+ + r"\s*}})"
+ )
+ if not re.search(pattern, template):
+ raise DistilabelUserError(
+ (
+ f"You required column name '{column}', but is not present in the template, "
+ "ensure the 'columns' match with the 'template' to avoid errors."
+ ),
+ page="components-gallery/tasks/textgeneration/",
)
- else:
- warnings.warn(
- "`use_system_prompt` is set to `True`, but no `system_prompt` in input batch, so it will be ignored.",
- UserWarning,
- stacklevel=2,
+
+ for column in self.columns:
+ check_column_in_template(column, self.template)
+
+ self._template = Template(self.template)
+
+ def unload(self) -> None:
+ super().unload()
+ self._template = None
+
+ @property
+ def inputs(self) -> "StepColumns":
+ """The input for the task is the `instruction` by default, or the `columns` given as input."""
+ columns = {column: True for column in self.columns}
+ columns["system_prompt"] = False
+ return columns
+
+ def _prepare_message_content(self, input: Dict[str, Any]) -> "ChatType":
+ """Prepares the content for the template and returns the formatted messages."""
+ fields = {column: input[column] for column in self.columns}
+ return [{"role": "user", "content": self._template.render(**fields)}]
+
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ """The input is formatted as a `ChatType` assuming that the instruction
+ is the first interaction from the user within a conversation."""
+ # Handle the previous expected errors, in case of custom columns there's more freedom
+ # and we cannot check it so easily.
+ if self.columns == ["instruction"]:
+ if is_openai_format(input["instruction"]):
+ raise DistilabelUserError(
+ "Providing `instruction` formatted as an OpenAI chat / conversation is"
+ " deprecated, you should use `ChatGeneration` with `messages` as input instead.",
+ page="components-gallery/tasks/textgeneration/",
+ )
+
+ if not isinstance(input["instruction"], str):
+ raise DistilabelUserError(
+ f"Input `instruction` must be a string. Got: {input['instruction']}.",
+ page="components-gallery/tasks/textgeneration/",
)
+
+ messages = self._prepare_message_content(input)
+
+ row_system_prompt = input.get("system_prompt")
+ if row_system_prompt:
+ messages.insert(0, {"role": "system", "content": row_system_prompt})
+
+ if self.system_prompt and not row_system_prompt:
+ messages.insert(0, {"role": "system", "content": self.system_prompt})
+
return messages # type: ignore
@property
@@ -146,7 +321,6 @@ class ChatGeneration(Task):
`:material-chat:`
Examples:
-
Generate text from a conversation in OpenAI chat format:
```python
@@ -189,20 +363,22 @@ def inputs(self) -> List[str]:
"""The input for the task are the `messages`."""
return ["messages"]
- def format_input(self, input: Dict[str, Any]) -> ChatType:
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
"""The input is formatted as a `ChatType` assuming that the messages provided
are already formatted that way i.e. following the OpenAI chat format."""
if not is_openai_format(input["messages"]):
- raise ValueError(
+ raise DistilabelUserError(
"Input `messages` must be an OpenAI chat-like format conversation. "
- f"Got: {input['messages']}. Please check: 'https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models'."
+ f"Got: {input['messages']}. Please check: 'https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models'.",
+ page="components-gallery/tasks/chatgeneration/",
)
if input["messages"][-1]["role"] != "user":
- raise ValueError(
+ raise DistilabelUserError(
"The last message must be from the user. Please check: "
- "'https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models'."
+ "'https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models'.",
+ page="components-gallery/tasks/chatgeneration/",
)
return input["messages"]
@@ -213,7 +389,7 @@ def outputs(self) -> List[str]:
return ["generation", "model_name"]
def format_output(
- self, output: Union[str, None], input: Dict[str, Any]
+ self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
"""The output is formatted as a dictionary with the `generation`. The `model_name`
will be automatically included within the `process` method of `Task`."""
diff --git a/src/distilabel/steps/tasks/typing.py b/src/distilabel/steps/tasks/typing.py
index ae9fd9519e..920a94c3b9 100644
--- a/src/distilabel/steps/tasks/typing.py
+++ b/src/distilabel/steps/tasks/typing.py
@@ -49,6 +49,8 @@ class OutlinesStructuredOutputType(TypedDict, total=False):
class InstructorStructuredOutputType(TypedDict, total=False):
"""TypedDict to represent the structured output configuration from `instructor`."""
+ format: Optional[Literal["json"]]
+ """One of "json"."""
schema: Union[Type[BaseModel], Dict[str, Any]]
"""The schema to use for the structured output, a `pydantic.BaseModel` class. """
mode: Optional[str]
diff --git a/src/distilabel/steps/tasks/ultrafeedback.py b/src/distilabel/steps/tasks/ultrafeedback.py
index e8e98759ec..aeb57bda36 100644
--- a/src/distilabel/steps/tasks/ultrafeedback.py
+++ b/src/distilabel/steps/tasks/ultrafeedback.py
@@ -12,18 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib.resources as importlib_resources
import re
-import sys
-
-if sys.version_info < (3, 9):
- import importlib_resources
-else:
- import importlib.resources as importlib_resources
-
from typing import Any, Dict, List, Literal, Optional, Union
+import orjson
from jinja2 import Template
from pydantic import PrivateAttr
+from typing_extensions import override
from distilabel.steps.tasks.base import Task
from distilabel.steps.tasks.typing import ChatType
@@ -63,7 +59,6 @@ class UltraFeedback(Task):
- [`UltraFeedback - GitHub Repository`](https://github.com/OpenBMB/UltraFeedback)
Examples:
-
Rate generations from different LLMs based on the selected aspect:
```python
@@ -74,13 +69,14 @@ class UltraFeedback(Task):
ultrafeedback = UltraFeedback(
llm=InferenceEndpointsLLM(
model_id="mistralai/Mistral-7B-Instruct-v0.2",
- )
+ ),
+ use_default_structured_output=False
)
ultrafeedback.load()
result = next(
- chat.process(
+ ultrafeedback.process(
[
{
"instruction": "How much is 2+2?",
@@ -101,8 +97,83 @@ class UltraFeedback(Task):
# ]
```
- Citations:
+ Rate generations from different LLMs based on the honesty, using the default structured output:
+
+ ```python
+ from distilabel.steps.tasks import UltraFeedback
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ ultrafeedback = UltraFeedback(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+ aspect="honesty"
+ )
+
+ ultrafeedback.load()
+ result = next(
+ ultrafeedback.process(
+ [
+ {
+ "instruction": "How much is 2+2?",
+ "generations": ["4", "and a car"],
+ }
+ ]
+ )
+ )
+ # result
+ # [{'instruction': 'How much is 2+2?',
+ # 'generations': ['4', 'and a car'],
+ # 'ratings': [5, 1],
+ # 'rationales': ['The response is correct and confident, as it directly answers the question without expressing any uncertainty or doubt.',
+ # "The response is confidently incorrect, as it provides unrelated information ('a car') and does not address the question. The model shows no uncertainty or indication that it does not know the answer."],
+ # 'distilabel_metadata': {'raw_output_ultra_feedback_0': '{"ratings": [\\n 5,\\n 1\\n] \\n\\n,"rationales": [\\n "The response is correct and confident, as it directly answers the question without expressing any uncertainty or doubt.",\\n "The response is confidently incorrect, as it provides unrelated information (\'a car\') and does not address the question. The model shows no uncertainty or indication that it does not know the answer."\\n] }'},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+
+ Rate generations from different LLMs based on the helpfulness, using the default structured output:
+
+ ```python
+ from distilabel.steps.tasks import UltraFeedback
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ ultrafeedback = UltraFeedback(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={"max_new_tokens": 512},
+ ),
+ aspect="helpfulness"
+ )
+
+ ultrafeedback.load()
+
+ result = next(
+ ultrafeedback.process(
+ [
+ {
+ "instruction": "How much is 2+2?",
+ "generations": ["4", "and a car"],
+ }
+ ]
+ )
+ )
+ # result
+ # [{'instruction': 'How much is 2+2?',
+ # 'generations': ['4', 'and a car'],
+ # 'ratings': [1, 5],
+ # 'rationales': ['Text 1 is clear and relevant, providing the correct answer to the question. It is also not lengthy and does not contain repetition. However, it lacks comprehensive information or detailed description.',
+ # 'Text 2 is neither clear nor relevant to the task. It does not provide any useful information and seems unrelated to the question.'],
+ # 'rationales_for_rating': ['Text 1 is rated as Correct (3) because it provides the accurate answer to the question, but lacks comprehensive information or detailed description.',
+ # 'Text 2 is rated as Severely Incorrect (1) because it does not provide any relevant information and seems unrelated to the question.'],
+ # 'types': [1, 3, 1],
+ # 'distilabel_metadata': {'raw_output_ultra_feedback_0': '{ \\n "ratings": [\\n 1,\\n 5\\n ]\\n ,\\n "rationales": [\\n "Text 1 is clear and relevant, providing the correct answer to the question. It is also not lengthy and does not contain repetition. However, it lacks comprehensive information or detailed description.",\\n "Text 2 is neither clear nor relevant to the task. It does not provide any useful information and seems unrelated to the question."\\n ]\\n ,\\n "rationales_for_rating": [\\n "Text 1 is rated as Correct (3) because it provides the accurate answer to the question, but lacks comprehensive information or detailed description.",\\n "Text 2 is rated as Severely Incorrect (1) because it does not provide any relevant information and seems unrelated to the question."\\n ]\\n ,\\n "types": [\\n 1, 3,\\n 1\\n ]\\n }'},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+
+ Citations:
```
@misc{cui2024ultrafeedbackboostinglanguagemodels,
title={UltraFeedback: Boosting Language Models with Scaled AI Feedback},
@@ -135,6 +206,7 @@ class UltraFeedback(Task):
)
)
_template: Optional["Template"] = PrivateAttr(default=...)
+ _can_be_used_with_offline_batch_generation = True
def load(self) -> None:
"""Loads the Jinja2 template for the given `aspect`."""
@@ -185,7 +257,7 @@ def outputs(self) -> List[str]:
return columns + ["model_name"]
def format_output(
- self, output: Union[str, None], input: Dict[str, Any]
+ self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
"""The output is formatted as a dictionary with the `ratings` and `rationales` for
each of the provided `generations` for the given `instruction`. The `model_name`
@@ -202,12 +274,15 @@ def format_output(
`ratings`, and `rationales-for-ratings` for each of the provided `generations` for the
given `instruction` if the provided aspect is either `helpfulness` or `truthfulness`.
"""
+ assert input is not None, "Input is required to format the output."
+
if self.aspect in [
"honesty",
"instruction-following",
"overall-rating",
]:
return self._format_ratings_rationales_output(output, input)
+
return self._format_types_ratings_rationales_output(output, input)
def _format_ratings_rationales_output(
@@ -220,6 +295,9 @@ def _format_ratings_rationales_output(
"rationales": [None] * len(input["generations"]),
}
+ if self.use_default_structured_output:
+ return self._format_structured_output(output, input)
+
pattern = r"Rating: (.+?)\nRationale: (.+)"
sections = output.split("\n\n")
@@ -234,9 +312,11 @@ def _format_ratings_rationales_output(
formatted_outputs.append(
{
- "ratings": int(re.findall(r"\b\d+\b", matches.group(1))[0])
- if matches.group(1) not in ["None", "N/A"]
- else None,
+ "ratings": (
+ int(re.findall(r"\b\d+\b", matches.group(1))[0])
+ if matches.group(1) not in ["None", "N/A"]
+ else None
+ ),
"rationales": matches.group(2),
}
)
@@ -254,6 +334,9 @@ def _format_types_ratings_rationales_output(
"rationales-for-ratings": [None] * len(input["generations"]),
}
+ if self.use_default_structured_output:
+ return self._format_structured_output(output, input)
+
pattern = r"Type: (.+?)\nRationale: (.+?)\nRating: (.+?)\nRationale: (.+)"
sections = output.split("\n\n")
@@ -276,14 +359,135 @@ def _format_types_ratings_rationales_output(
formatted_outputs.append(
{
- "types": int(re.findall(r"\b\d+\b", matches.group(1))[0])
- if matches.group(1) not in ["None", "N/A"]
- else None,
+ "types": (
+ int(re.findall(r"\b\d+\b", matches.group(1))[0])
+ if matches.group(1) not in ["None", "N/A"]
+ else None
+ ),
"rationales": matches.group(2),
- "ratings": int(re.findall(r"\b\d+\b", matches.group(3))[0])
- if matches.group(3) not in ["None", "N/A"]
- else None,
+ "ratings": (
+ int(re.findall(r"\b\d+\b", matches.group(3))[0])
+ if matches.group(3) not in ["None", "N/A"]
+ else None
+ ),
"rationales-for-ratings": matches.group(4),
}
)
return group_dicts(*formatted_outputs)
+
+ @override
+ def get_structured_output(self) -> Dict[str, Any]:
+ """Creates the json schema to be passed to the LLM, to enforce generating
+ a dictionary with the output which can be directly parsed as a python dictionary.
+
+ The schema corresponds to the following:
+
+ ```python
+ from pydantic import BaseModel
+ from typing import List
+
+ class SchemaUltraFeedback(BaseModel):
+ ratings: List[int]
+ rationales: List[str]
+
+ class SchemaUltraFeedbackWithType(BaseModel):
+ types: List[Optional[int]]
+ ratings: List[int]
+ rationales: List[str]
+ rationales_for_rating: List[str]
+ ```
+
+ Returns:
+ JSON Schema of the response to enforce.
+ """
+ if self.aspect in [
+ "honesty",
+ "instruction-following",
+ "overall-rating",
+ ]:
+ return {
+ "properties": {
+ "ratings": {
+ "items": {"type": "integer"},
+ "title": "Ratings",
+ "type": "array",
+ },
+ "rationales": {
+ "items": {"type": "string"},
+ "title": "Rationales",
+ "type": "array",
+ },
+ },
+ "required": ["ratings", "rationales"],
+ "title": "SchemaUltraFeedback",
+ "type": "object",
+ }
+ return {
+ "properties": {
+ "types": {
+ "items": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
+ "title": "Types",
+ "type": "array",
+ },
+ "ratings": {
+ "items": {"type": "integer"},
+ "title": "Ratings",
+ "type": "array",
+ },
+ "rationales": {
+ "items": {"type": "string"},
+ "title": "Rationales",
+ "type": "array",
+ },
+ "rationales_for_rating": {
+ "items": {"type": "string"},
+ "title": "Rationales For Rating",
+ "type": "array",
+ },
+ },
+ "required": ["types", "ratings", "rationales", "rationales_for_rating"],
+ "title": "SchemaUltraFeedbackWithType",
+ "type": "object",
+ }
+
+ def _format_structured_output(
+ self, output: str, input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """Parses the structured response, which should correspond to a dictionary
+ with either `positive`, or `positive` and `negative` keys.
+
+ Args:
+ output: The output from the `LLM`.
+
+ Returns:
+ Formatted output.
+ """
+ try:
+ return orjson.loads(output)
+ except orjson.JSONDecodeError:
+ if self.aspect in [
+ "honesty",
+ "instruction-following",
+ "overall-rating",
+ ]:
+ return {
+ "ratings": [None] * len(input["generations"]),
+ "rationales": [None] * len(input["generations"]),
+ }
+ return {
+ "ratings": [None] * len(input["generations"]),
+ "rationales": [None] * len(input["generations"]),
+ "types": [None] * len(input["generations"]),
+ "rationales-for-ratings": [None] * len(input["generations"]),
+ }
+
+ @override
+ def _sample_input(self) -> ChatType:
+ return self.format_input(
+ {
+ "instruction": f"",
+ "generations": [
+ f"" for i in range(2)
+ ],
+ }
+ )
diff --git a/src/distilabel/steps/tasks/urial.py b/src/distilabel/steps/tasks/urial.py
new file mode 100644
index 0000000000..705b9c4883
--- /dev/null
+++ b/src/distilabel/steps/tasks/urial.py
@@ -0,0 +1,124 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.resources as importlib_resources
+from typing import TYPE_CHECKING, Any, Dict, Union
+
+from jinja2 import Template
+
+from distilabel.steps.tasks import Task
+
+if TYPE_CHECKING:
+ from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.typing import StepColumns
+
+
+class URIAL(Task):
+ """Generates a response using a non-instruct fine-tuned model.
+
+ `URIAL` is a pre-defined task that generates a response using a non-instruct fine-tuned
+ model. This task is used to generate a response based on the conversation provided as
+ input.
+
+ Input columns:
+ - instruction (`str`, optional): The instruction to generate a response from.
+ - conversation (`List[Dict[str, str]]`, optional): The conversation to generate
+ a response from (the last message must be from the user).
+
+ Output columns:
+ - generation (`str`): The generated response.
+ - model_name (`str`): The name of the model used to generate the response.
+
+ Categories:
+ - text-generation
+
+ References:
+ - [The Unlocking Spell on Base LLMs: Rethinking Alignment via In-Context Learning](https://arxiv.org/abs/2312.01552)
+
+ Examples:
+ Generate text from an instruction:
+
+ ```python
+ from distilabel.llms import vLLM
+ from distilabel.steps.tasks import URIAL
+
+ step = URIAL(
+ llm=vLLM(
+ model="meta-llama/Meta-Llama-3.1-8B",
+ generation_kwargs={"temperature": 0.7},
+ ),
+ )
+
+ step.load()
+
+ results = next(
+ step.process(inputs=[{"instruction": "What's the most most common type of cloud?"}])
+ )
+ # [
+ # {
+ # 'instruction': "What's the most most common type of cloud?",
+ # 'generation': 'Clouds are classified into three main types, high, middle, and low. The most common type of cloud is the middle cloud.',
+ # 'distilabel_metadata': {...},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-8B'
+ # }
+ # ]
+ ```
+ """
+
+ def load(self) -> None:
+ """Loads the Jinja2 template for the given `aspect`."""
+ super().load()
+
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "urial.jinja2"
+ )
+
+ self._template = Template(open(_path).read())
+
+ @property
+ def inputs(self) -> "StepColumns":
+ return {"instruction": False, "conversation": False}
+
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ messages = (
+ [{"role": "user", "content": input["instruction"]}]
+ if "instruction" in input
+ else input["conversation"]
+ )
+
+ if messages[-1]["role"] != "user":
+ raise ValueError("The last message must be from the user.")
+
+ return [{"role": "user", "content": self._template.render(messages=messages)}]
+
+ @property
+ def outputs(self) -> "StepColumns":
+ return ["generation", "model_name"]
+
+ def format_output(
+ self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ if output is None:
+ return {"generation": None}
+
+ response = output.split("\n\n# User")[0]
+ if response.startswith("\n\n"):
+ response = response[2:]
+ response = response.strip()
+
+ return {"generation": response}
diff --git a/src/distilabel/steps/truncate.py b/src/distilabel/steps/truncate.py
new file mode 100644
index 0000000000..6e68af6630
--- /dev/null
+++ b/src/distilabel/steps/truncate.py
@@ -0,0 +1,148 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+from typing import TYPE_CHECKING, Any, Callable, List, Optional
+
+from typing_extensions import override
+
+from distilabel.steps.base import Step, StepInput
+
+if TYPE_CHECKING:
+ from distilabel.steps.typing import StepOutput
+
+
+class TruncateTextColumn(Step):
+ """Truncate a row using a tokenizer or the number of characters.
+
+ `TruncateTextColumn` is a `Step` that truncates a row according to the max length. If
+ the `tokenizer` is provided, then the row will be truncated using the tokenizer,
+ and the `max_length` will be used as the maximum number of tokens, otherwise it will
+ be used as the maximum number of characters. The `TruncateTextColumn` step is useful when one
+ wants to truncate a row to a certain length, to avoid posterior errors in the model due
+ to the length.
+
+ Attributes:
+ column: the column to truncate. Defaults to `"text"`.
+ max_length: the maximum length to use for truncation.
+ If a `tokenizer` is given, corresponds to the number of tokens,
+ otherwise corresponds to the number of characters. Defaults to `8192`.
+ tokenizer: the name of the tokenizer to use. If provided, the row will be
+ truncated using the tokenizer. Defaults to `None`.
+
+ Input columns:
+ - dynamic (determined by `column` attribute): The columns to be truncated, defaults to "text".
+
+ Output columns:
+ - dynamic (determined by `column` attribute): The truncated column.
+
+ Categories:
+ - text-manipulation
+
+ Examples:
+ Truncating a row to a given number of tokens:
+
+ ```python
+ from distilabel.steps import TruncateTextColumn
+
+ trunc = TruncateTextColumn(
+ tokenizer="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ max_length=4,
+ column="text"
+ )
+
+ trunc.load()
+
+ result = next(
+ trunc.process(
+ [
+ {"text": "This is a sample text that is longer than 10 characters"}
+ ]
+ )
+ )
+ # result
+ # [{'text': 'This is a sample'}]
+ ```
+
+ Truncating a row to a given number of characters:
+
+ ```python
+ from distilabel.steps import TruncateTextColumn
+
+ trunc = TruncateTextColumn(max_length=10)
+
+ trunc.load()
+
+ result = next(
+ trunc.process(
+ [
+ {"text": "This is a sample text that is longer than 10 characters"}
+ ]
+ )
+ )
+ # result
+ # [{'text': 'This is a '}]
+ ```
+ """
+
+ column: str = "text"
+ max_length: int = 8192
+ tokenizer: Optional[str] = None
+ _truncator: Optional[Callable[[str], str]] = None
+ _tokenizer: Optional[Any] = None
+
+ def load(self):
+ super().load()
+ if self.tokenizer:
+ if not importlib.util.find_spec("transformers"):
+ raise ImportError(
+ "`transformers` is needed to tokenize, but is not installed. "
+ "Please install it using `pip install transformers`."
+ )
+
+ from transformers import AutoTokenizer
+
+ self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer)
+ self._truncator = self._truncate_with_tokenizer
+ else:
+ self._truncator = self._truncate_with_length
+
+ @property
+ def inputs(self) -> List[str]:
+ return [self.column]
+
+ @property
+ def outputs(self) -> List[str]:
+ return self.inputs
+
+ def _truncate_with_length(self, text: str) -> str:
+ """Truncates the text according to the number of characters."""
+ return text[: self.max_length]
+
+ def _truncate_with_tokenizer(self, text: str) -> str:
+ """Truncates the text according to the number of characters using the tokenizer."""
+ return self._tokenizer.decode(
+ self._tokenizer.encode(
+ text,
+ add_special_tokens=False,
+ max_length=self.max_length,
+ truncation=True,
+ )
+ )
+
+ @override
+ def process(self, inputs: StepInput) -> "StepOutput":
+ for input in inputs:
+ input[self.column] = self._truncator(input[self.column])
+ yield inputs
diff --git a/src/distilabel/steps/typing.py b/src/distilabel/steps/typing.py
index 3bdd2e2553..720037a74f 100644
--- a/src/distilabel/steps/typing.py
+++ b/src/distilabel/steps/typing.py
@@ -12,10 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Tuple, Union
StepOutput = Iterator[List[Dict[str, Any]]]
-"""StepOutput is an alias of the typing `Iterator[List[Dict[str, Any]]]`"""
+"""`StepOutput` is an alias of the typing `Iterator[List[Dict[str, Any]]]`"""
GeneratorStepOutput = Iterator[Tuple[List[Dict[str, Any]], bool]]
-"""GeneratorStepOutput is an alias of the typing `Iterator[Tuple[List[Dict[str, Any]], bool]]`"""
+"""`GeneratorStepOutput` is an alias of the typing `Iterator[Tuple[List[Dict[str, Any]], bool]]`"""
+
+StepColumns = Union[List[str], Dict[str, bool]]
+"""`StepColumns` is an alias of the typing `Union[List[str], Dict[str, bool]]` used by the
+`inputs` and `outputs` properties of an `Step`. In the case of a `List[str]`, it is a list
+with the required columns. In the case of a `Dict[str, bool]`, it is a dictionary where
+the keys are the columns and the values are booleans indicating whether the column is
+required or not.
+"""
diff --git a/src/distilabel/utils/card/distilabel_template.md b/src/distilabel/utils/card/distilabel_template.md
index c51e9a5953..38daa7f857 100644
--- a/src/distilabel/utils/card/distilabel_template.md
+++ b/src/distilabel/utils/card/distilabel_template.md
@@ -70,6 +70,21 @@ ds = load_dataset("{{ repo_id }}")
{% endfor %}
+{% if artifacts %}
+## Artifacts
+
+{% for step_name, artifacts in artifacts.items() %}
+* **Step**: `{{ step_name }}`
+ {% for artifact in artifacts %}
+ * **Artifact name**: `{{ artifact.name }}`
+ {% for name, value in artifact.metadata.items() %}
+ * `{{ name }}`: {{ value }}
+ {% endfor %}
+ {% endfor %}
+{% endfor %}
+
+{% endif %}
+
{% if references %}
## References
diff --git a/src/distilabel/utils/docstring.py b/src/distilabel/utils/docstring.py
index d33b503b8a..6171b72b2e 100644
--- a/src/distilabel/utils/docstring.py
+++ b/src/distilabel/utils/docstring.py
@@ -165,7 +165,8 @@ def parse_google_docstring(func: Callable) -> Docstring: # noqa: C901
elif section_name == "examples":
# Parse examples into a dictionary
example_items = re.findall(
- r"(\w[\w\s]*?):\s*\n\s*```python\n(.*?)\n\s*```",
+ r"""([\w,()'][\w\s,()=`!'"]*?):\s*\n?\s*```python\n(.*?)\n\s*```""",
+ # r"(\w[\w\s]*?):\s*\n?\s*```python\n(.*?)\n\s*```",
section_content,
re.DOTALL,
)
@@ -217,7 +218,6 @@ def get_bibtex(ref: str) -> str:
The bibtex style citation.
Examples:
-
```python
cite = get_bibtex(r"https://arxiv.org/abs/2406.18518")
@misc{other,
@@ -236,9 +236,7 @@ def get_bibtex(ref: str) -> str:
from bs4 import BeautifulSoup
if not ref.startswith("https://arxiv.org"):
- raise ValueError(
- f"The url must start with of `https://arxiv.org`, but got: {ref}"
- )
+ raise ValueError(f"The url must start with `https://arxiv.org`, but got: {ref}")
response: bytes = requests.get(
rf"https://arxiv2bibtex.org/?q={quote_plus(ref)}&format=bibtex"
)
diff --git a/src/distilabel/utils/itertools.py b/src/distilabel/utils/itertools.py
index 2555f3b262..34accced2b 100644
--- a/src/distilabel/utils/itertools.py
+++ b/src/distilabel/utils/itertools.py
@@ -12,11 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import sys
from itertools import zip_longest
from typing import Any, Iterable, Literal, Tuple, TypeVar
T = TypeVar("T")
+# https://docs.python.org/3/library/itertools.html#itertools.batched
+if sys.version_info >= (3, 12):
+ from itertools import batched
+else:
+ from itertools import islice
+
+ def batched(iterable: Iterable[T], n: int) -> Iterable[T]:
+ # batched('ABCDEFG', 3) → ABC DEF G
+ if n < 1:
+ raise ValueError("n must be at least one")
+ iterator = iter(iterable)
+ while batch := tuple(islice(iterator, n)):
+ yield batch
+
# Copy pasted from https://docs.python.org/3/library/itertools.html#itertools-recipes
# Just added the type hints and use `if`s instead of `match`
diff --git a/src/distilabel/utils/logging.py b/src/distilabel/utils/logging.py
index 0e409863ff..994c81e321 100644
--- a/src/distilabel/utils/logging.py
+++ b/src/distilabel/utils/logging.py
@@ -14,15 +14,16 @@
import logging
import multiprocessing as mp
-import os
import warnings
from logging import FileHandler
from logging.handlers import QueueHandler, QueueListener
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union
from rich.logging import RichHandler
+from distilabel import envs
+
if TYPE_CHECKING:
from queue import Queue
@@ -46,7 +47,9 @@
def setup_logging(
- log_queue: Optional["Queue[Any]"] = None, filename: Optional[str] = None
+ log_queue: Optional["Queue[Any]"] = None,
+ filename: Optional[str] = None,
+ logging_handlers: Optional[List[logging.Handler]] = None,
) -> None:
"""Sets up logging to use a queue across all processes."""
global queue_listener
@@ -59,25 +62,30 @@ def setup_logging(
# If the current process is the main process, set up a `QueueListener`
# to handle logs from all subprocesses
if mp.current_process().name == "MainProcess" and filename:
+ if logging_handlers is None:
+ logging_handlers = []
+
formatter = logging.Formatter("['%(name)s'] %(message)s")
handler = RichHandler(rich_tracebacks=True)
handler.setFormatter(formatter)
+ logging_handlers.append(handler)
+
if not Path(filename).parent.exists():
Path(filename).parent.mkdir(parents=True, exist_ok=True)
-
- file_handler = FileHandler(filename, delay=True)
+ file_handler = FileHandler(filename, delay=True, encoding="utf-8")
file_formatter = logging.Formatter(
"[%(asctime)s] %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
file_handler.setFormatter(file_formatter)
+ logging_handlers.append(file_handler)
if log_queue is not None:
queue_listener = QueueListener(
- log_queue, handler, file_handler, respect_handler_level=True
+ log_queue, *logging_handlers, respect_handler_level=True
)
queue_listener.start()
- log_level = os.environ.get("DISTILABEL_LOG_LEVEL", "INFO").upper()
+ log_level = envs.DISTILABEL_LOG_LEVEL
if log_level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
warnings.warn(
f"Invalid log level '{log_level}', using default 'INFO' instead.",
diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py
index 2d28c55fb6..621f4b61dc 100644
--- a/src/distilabel/utils/mkdocs/components_gallery.py
+++ b/src/distilabel/utils/mkdocs/components_gallery.py
@@ -16,6 +16,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, List, Union
+import pandas as pd
from jinja2 import Template
from mkdocs.config.base import Config
from mkdocs.config.config_options import Type
@@ -76,15 +77,68 @@
_STEPS_CATEGORY_TO_ICON = {
"text-generation": ":material-text-box-edit:",
+ "chat-generation": ":material-chat:",
+ "text-classification": ":material-label:",
+ "text-manipulation": ":material-receipt-text-edit:",
"evol": ":material-dna:",
- "preference": ":material-poll:",
"critique": ":material-comment-edit:",
"scorer": ":octicons-number-16:",
+ "preference": ":material-poll:",
"embedding": ":material-vector-line:",
- "format": ":material-format-list-bulleted:",
+ "clustering": ":material-scatter-plot:",
+ "columns": ":material-table-column:",
"filtering": ":material-filter:",
- "save": ":material-content-save:",
+ "format": ":material-format-list-bulleted:",
"load": ":material-file-download:",
+ "execution": ":octicons-code-16:",
+ "save": ":material-content-save:",
+}
+
+_STEP_CATEGORY_TO_DESCRIPTION = {
+ "text-generation": "Text generation steps are used to generate text based on a given prompt.",
+ "chat-generation": "Chat generation steps are used to generate text based on a conversation.",
+ "text-classification": "Text classification steps are used to classify text into a category.",
+ "text-manipulation": "Text manipulation steps are used to manipulate or rewrite an input text.",
+ "evol": "Evol steps are used to rewrite input text and evolve it to a higher quality.",
+ "critique": "Critique steps are used to provide feedback on the quality of the data with a written explanation.",
+ "scorer": "Scorer steps are used to evaluate and score the data with a numerical value.",
+ "preference": "Preference steps are used to collect preferences on the data with numerical values or ranks.",
+ "embedding": "Embedding steps are used to generate embeddings for the data.",
+ "clustering": "Clustering steps are used to group similar data points together.",
+ "columns": "Columns steps are used to manipulate columns in the data.",
+ "filtering": "Filtering steps are used to filter the data based on some criteria.",
+ "format": "Format steps are used to format the data.",
+ "load": "Load steps are used to load the data.",
+ "execution": "Executes python functions.",
+ "save": "Save steps are used to save the data.",
+}
+
+assert list(_STEP_CATEGORY_TO_DESCRIPTION.keys()) == list(
+ _STEPS_CATEGORY_TO_ICON.keys()
+)
+
+_STEP_CATEGORIES = list(_STEP_CATEGORY_TO_DESCRIPTION.keys())
+_STEP_CATEGORY_TABLE = pd.DataFrame(
+ {
+ "Icon": [_STEPS_CATEGORY_TO_ICON[category] for category in _STEP_CATEGORIES],
+ "Category": _STEP_CATEGORIES,
+ "Description": [
+ _STEP_CATEGORY_TO_DESCRIPTION[category] for category in _STEP_CATEGORIES
+ ],
+ }
+).to_markdown(index=False)
+_STEP_CATEGORY_TABLE_DESCRIPTION = [
+ '??? info "Category Overview"',
+ " The gallery page showcases the different types of components within `distilabel`.",
+ "",
+]
+for row in _STEP_CATEGORY_TABLE.split("\n"):
+ _STEP_CATEGORY_TABLE_DESCRIPTION.append(f" {row}")
+_STEP_CATEGORY_TABLE_DESCRIPTION = "\n".join(_STEP_CATEGORY_TABLE_DESCRIPTION)
+
+_CATEGORY_ORDER_INDEX = {
+ category: idx
+ for idx, category in enumerate(list(_STEP_CATEGORY_TO_DESCRIPTION.keys()))
}
@@ -209,6 +263,18 @@ def _generate_steps_pages(self, src_dir: Path, steps: list) -> List[str]:
steps_gallery_page_path = src_dir / paths[0]
steps_gallery_page_path.parent.mkdir(parents=True, exist_ok=True)
+ # Sort steps based on the index of their first category in the 'category_order'
+ steps = sorted(
+ steps,
+ key=lambda step: _CATEGORY_ORDER_INDEX.get(
+ step["docstring"]["categories"][0]
+ if step["docstring"]["categories"]
+ else float("inf"),
+ float("inf"),
+ ),
+ reverse=True,
+ )
+
# Create detail page for each `Step`
for step in steps:
docstring = step["docstring"]
@@ -216,6 +282,11 @@ def _generate_steps_pages(self, src_dir: Path, steps: list) -> List[str]:
first_category = docstring["categories"][0]
docstring["icon"] = _STEPS_CATEGORY_TO_ICON.get(first_category, "")
+ if docstring["icon"]:
+ assert (
+ docstring["icon"] in _STEPS_CATEGORY_TO_ICON.values()
+ ), f"Icon {docstring['icon']} not found in _STEPS_CATEGORY_TO_ICON"
+
name = step["name"]
content = _STEP_DETAIL_TEMPLATE.render(
@@ -234,10 +305,10 @@ def _generate_steps_pages(self, src_dir: Path, steps: list) -> List[str]:
paths.append(step_path)
- # Create the `components-gallery/steps.md` file
+ # Create the `components-gallery/steps/index.md` file
content = _COMPONENTS_LIST_TEMPLATE.render(
title="Steps Gallery",
- description="",
+ description=_STEP_CATEGORY_TABLE_DESCRIPTION,
components=steps,
default_icon=":material-step-forward:",
)
@@ -262,12 +333,27 @@ def _generate_tasks_pages(self, src_dir: Path, tasks: list) -> List[str]:
tasks_gallery_page_path = src_dir / paths[0]
tasks_gallery_page_path.parent.mkdir(parents=True, exist_ok=True)
+ # Sort tasks based on the index of their first category in the 'category_order'
+ tasks = sorted(
+ tasks,
+ key=lambda task: _CATEGORY_ORDER_INDEX.get(
+ task["docstring"]["categories"][0]
+ if task["docstring"]["categories"]
+ else float("inf"),
+ float("inf"),
+ ),
+ )
+
# Create detail page for each `Task`
for task in tasks:
docstring = task["docstring"]
if docstring["icon"] == "" and docstring["categories"]:
first_category = docstring["categories"][0]
docstring["icon"] = _STEPS_CATEGORY_TO_ICON.get(first_category, "")
+ if docstring["icon"]:
+ assert (
+ docstring["icon"] in _STEPS_CATEGORY_TO_ICON.values()
+ ), f"Icon {docstring['icon']} not found in _STEPS_CATEGORY_TO_ICON"
name = task["name"]
@@ -287,10 +373,10 @@ def _generate_tasks_pages(self, src_dir: Path, tasks: list) -> List[str]:
paths.append(task_path)
- # Create the `components-gallery/steps/index.md` file
+ # Create the `components-gallery/tasks/index.md` file
content = _COMPONENTS_LIST_TEMPLATE.render(
title="Tasks Gallery",
- description="",
+ description=_STEP_CATEGORY_TABLE_DESCRIPTION,
components=tasks,
default_icon=":material-check-outline:",
)
diff --git a/src/distilabel/utils/serialization.py b/src/distilabel/utils/serialization.py
index 8f32afc2eb..873ff20721 100644
--- a/src/distilabel/utils/serialization.py
+++ b/src/distilabel/utils/serialization.py
@@ -19,6 +19,8 @@
import orjson
+from distilabel.mixins.runtime_parameters import RuntimeParametersMixin
+
if sys.version_info < (3, 11):
from enum import EnumMeta as EnumType
else:
@@ -205,7 +207,11 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"_values": {x.name: x.value for x in v}, # type: ignore
}
elif isinstance(v, list):
- dump[k] = {str(i): list_v for i, list_v in enumerate(v)}
+ obj_list = getattr(obj, k)
+ if isinstance(obj_list, list) and isinstance(
+ obj_list[0], RuntimeParametersMixin
+ ):
+ dump[k] = {str(i): list_v for i, list_v in enumerate(v)}
# Grab the fields that need extra care (LLMs from inside tasks)
to_update = _extra_serializable_fields(obj)
diff --git a/src/distilabel/utils/typing_.py b/src/distilabel/utils/typing_.py
index ac75aba63c..26dfa1b203 100644
--- a/src/distilabel/utils/typing_.py
+++ b/src/distilabel/utils/typing_.py
@@ -13,8 +13,9 @@
# limitations under the License.
import inspect
-from typing import Any
+from typing import Any, Union
+from pydantic.types import _SecretField
from typing_extensions import Annotated, get_args, get_origin
@@ -38,3 +39,31 @@ def is_parameter_annotated_with(parameter: inspect.Parameter, annotation: Any) -
return True
return False
+
+
+def extract_annotation_inner_type(type_hint: Any) -> Any:
+ """Extracts the inner type of an annotation.
+
+ Args:
+ type_hint: The type hint to extract the inner type from.
+
+ Returns:
+ The inner type of the `RuntimeParameter` type hint.
+ """
+ type_hint_args = get_args(type_hint)
+ if get_origin(type_hint) is Annotated:
+ return extract_annotation_inner_type(type_hint_args[0])
+
+ if get_origin(type_hint) is Union and type(None) in type_hint_args:
+ return extract_annotation_inner_type(type_hint_args[0])
+
+ return type_hint
+
+
+def is_type_pydantic_secret_field(type_: type) -> bool:
+ """Checks if a type is a Pydantic `_SecretField`.
+
+ Returns:
+ `True` if the type is a Pydantic `_SecretField`, `False` otherwise.
+ """
+ return inspect.isclass(type_) and issubclass(type_, _SecretField)
diff --git a/tests/integration/test_caching_steps.py b/tests/integration/test_caching_steps.py
new file mode 100644
index 0000000000..5ed8af993f
--- /dev/null
+++ b/tests/integration/test_caching_steps.py
@@ -0,0 +1,499 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tempfile import TemporaryDirectory
+from typing import TYPE_CHECKING, Any, Dict, Generator, List
+from unittest import mock
+from uuid import uuid4
+
+from pydantic import PrivateAttr
+
+from distilabel.pipeline import Pipeline
+from distilabel.steps import LoadDataFromDicts
+from distilabel.steps.base import Step, StepInput
+
+if TYPE_CHECKING:
+ from distilabel.pipeline.batch import _Batch
+
+
+class DummyStep(Step):
+ attr: int = 5
+ do_fail: bool = False
+ _ctr: int = PrivateAttr(default=0)
+
+ _random: str = PrivateAttr(default="")
+
+ def load(self) -> None:
+ super().load()
+ self._random = str(uuid4())
+
+ @property
+ def inputs(self) -> List[str]:
+ return ["instruction"]
+
+ def process(self, inputs: StepInput) -> Generator[List[Dict[str, Any]], None, None]:
+ for input in inputs:
+ input["response"] = f"I don't know - {self._ctr} - {self._random}"
+ self._ctr += 1
+
+ if self.do_fail:
+ raise ValueError("The step failed")
+ yield inputs
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["response"]
+
+
+class DummyStep2(DummyStep):
+ def process(
+ self, *inputs: StepInput
+ ) -> Generator[List[Dict[str, Any]], None, None]:
+ outputs = []
+ for input_a, input_b in zip(*inputs):
+ output = {**input_a, **input_b}
+ output["response"] = f"I don't know - {self._ctr}"
+ self._ctr += 1
+ outputs.append(output)
+ yield outputs
+
+
+class OtherDummyStep(DummyStep):
+ pass
+
+
+def test_cache() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=True,
+ )
+ step_c = DummyStep(
+ name="step_c",
+ input_batch_size=12,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_2"},
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b >> step_c
+
+ distiset_0 = pipeline.run()
+ distiset_1 = pipeline.run()
+ assert (
+ distiset_0["default"]["train"].to_list()
+ == distiset_1["default"]["train"].to_list()
+ )
+
+ distiset_2 = pipeline.run(use_cache=False)
+ assert len(distiset_2["default"]["train"]) == 48
+ assert (
+ distiset_0["default"]["train"].to_list()
+ != distiset_2["default"]["train"].to_list()
+ )
+
+
+def test_cache_with_step_cache_false() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=False,
+ )
+
+ step_generator >> step_a >> step_b
+
+ distiset_0 = pipeline.run()
+
+ with mock.patch.object(
+ pipeline, "_run_step", wraps=pipeline._run_step
+ ) as run_step_spy:
+ distiset_1 = pipeline.run()
+
+ # check that only `step_b` has been executed
+ assert run_step_spy.call_count == 1
+
+ assert (
+ distiset_0["default"]["train"].to_list()
+ != distiset_1["default"]["train"].to_list()
+ )
+
+
+def test_cache_with_step_changing() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b
+
+ distiset_0 = pipeline.run()
+
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ attr=103401234, # change attribute so step is not the same
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b
+
+ with mock.patch.object(
+ pipeline, "_run_step", wraps=pipeline._run_step
+ ) as run_step_spy:
+ distiset_1 = pipeline.run()
+
+ # check that only `step_b` has been executed
+ assert run_step_spy.call_count == 1
+
+ assert (
+ distiset_0["default"]["train"].to_list()
+ != distiset_1["default"]["train"].to_list()
+ )
+
+
+def test_cache_with_intermediate_step_cache_false() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=False,
+ )
+ step_c = DummyStep(
+ name="step_c",
+ input_batch_size=12,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_2"},
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b >> step_c
+
+ distiset_0 = pipeline.run()
+
+ with mock.patch.object(
+ pipeline, "_run_step", wraps=pipeline._run_step
+ ) as run_step_spy:
+ distiset_1 = pipeline.run()
+
+ # check that only `step_b` and `step_c` has been executed
+ assert run_step_spy.call_count == 2
+
+ assert (
+ distiset_0["default"]["train"].to_list()
+ != distiset_1["default"]["train"].to_list()
+ )
+
+
+def test_cache_adding_step() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b
+
+ distiset_0 = pipeline.run()
+
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=True,
+ )
+ step_c = DummyStep(
+ name="step_c",
+ input_batch_size=12,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_2"},
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b >> step_c
+
+ with mock.patch.object(
+ pipeline, "_run_step", wraps=pipeline._run_step
+ ) as run_step_spy:
+ distiset_1 = pipeline.run()
+
+ # check that only `step_c` has been executed
+ assert run_step_spy.call_count == 1
+
+ dict_0 = distiset_0["default"]["train"].to_dict()
+ dict_1 = distiset_1["default"]["train"].to_dict()
+ del dict_1["response_2"]
+ assert dict_0 == dict_1
+
+
+def test_cache_adding_step_with_multiple_predecessor() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ output_mappings={"response": "response_1"},
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ output_mappings={"response": "response_2"},
+ do_fail=False,
+ use_cache=True,
+ )
+
+ step_generator >> [step_a, step_b]
+
+ distiset_0 = pipeline.run()
+
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ output_mappings={"response": "response_1"},
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ output_mappings={"response": "response_2"},
+ do_fail=False,
+ use_cache=True,
+ )
+ step_c = DummyStep2(
+ name="step_c",
+ input_batch_size=12,
+ output_mappings={"response": "response_3"},
+ use_cache=True,
+ )
+
+ step_generator >> [step_a, step_b] >> step_c
+
+ with mock.patch.object(
+ pipeline, "_run_step", wraps=pipeline._run_step
+ ) as run_step_spy:
+ distiset_1 = pipeline.run()
+
+ # check that only `step_c` has been executed
+ assert run_step_spy.call_count == 1
+
+ for row_1, row_0_a, row_0_b in zip(
+ distiset_1["default"]["train"],
+ distiset_0["step_a"]["train"],
+ distiset_0["step_b"]["train"],
+ ):
+ assert row_1["response_1"] == row_0_a["response_1"]
+ assert row_1["response_2"] == row_0_b["response_2"]
+
+
+def test_cache_with_offset() -> None:
+ use_cache_per_step = True
+ do_fail = False
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline_0:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a", input_batch_size=4, use_cache=use_cache_per_step
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=do_fail,
+ use_cache=use_cache_per_step,
+ )
+ step_c = DummyStep(
+ name="step_c",
+ input_batch_size=12,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_2"},
+ use_cache=use_cache_per_step,
+ )
+
+ step_generator >> step_a >> step_b >> step_c
+
+ # Controlled failure of the Pipeline
+ original_process_batch = pipeline_0._process_batch
+
+ def _process_batch_wrapper(
+ batch: "_Batch", send_last_batch_flag: bool = True
+ ) -> None:
+ if batch.step_name == step_b.name and batch.seq_no == 2:
+ pipeline_0._stop_called = True
+ original_process_batch(batch)
+
+ # Run first time and stop the pipeline when specific batch received (simulate CTRL + C)
+ with mock.patch.object(pipeline_0, "_process_batch", _process_batch_wrapper):
+ distiset_0 = pipeline_0.run(use_cache=False)
+
+ assert len(distiset_0["default"]["train"]) == 12
+
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline_1:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a", input_batch_size=4, use_cache=use_cache_per_step
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=do_fail,
+ use_cache=use_cache_per_step,
+ )
+ step_c = DummyStep(
+ name="step_c",
+ input_batch_size=12,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_2"},
+ use_cache=use_cache_per_step,
+ )
+
+ step_generator >> step_a >> step_b >> step_c
+
+ distiset_1 = pipeline_1.run()
+
+ assert len(distiset_1["default"]["train"]) == 48
diff --git a/tests/integration/test_deduplication.py b/tests/integration/test_deduplication.py
new file mode 100644
index 0000000000..0121550f54
--- /dev/null
+++ b/tests/integration/test_deduplication.py
@@ -0,0 +1,50 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.pipeline import Pipeline
+from distilabel.steps import LoadDataFromDicts, MinHashDedup
+
+
+def test_minhash_deduplication() -> None:
+ with Pipeline() as pipeline:
+ ds_size = 1000
+ batch_size = 500
+ data = LoadDataFromDicts(
+ data=[
+ {"text": "This is a test document."},
+ {"text": "This document is a test."},
+ {"text": "Test document for duplication."},
+ {"text": "Document for duplication test."},
+ {"text": "This is another unique document."},
+ ]
+ * (ds_size // 5),
+ batch_size=batch_size,
+ )
+ minhash = MinHashDedup(
+ tokenizer="ngrams",
+ n=2,
+ threshold=0.9,
+ storage="disk",
+ input_batch_size=batch_size,
+ )
+ data >> minhash
+
+ distiset = pipeline.run(use_cache=False)
+ ds = distiset["default"]["train"]
+ ds_dedup = ds.filter(lambda x: x["keep_row_after_minhash_filtering"])
+ assert len(ds_dedup) == 4
+
+
+if __name__ == "__main__":
+ test_minhash_deduplication()
diff --git a/tests/integration/test_embedding_dedup.py b/tests/integration/test_embedding_dedup.py
new file mode 100644
index 0000000000..7806cf6761
--- /dev/null
+++ b/tests/integration/test_embedding_dedup.py
@@ -0,0 +1,130 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+import faiss
+import numpy as np
+
+from distilabel.pipeline import Pipeline
+from distilabel.steps import FaissNearestNeighbour, LoadDataFromDicts, StepInput, step
+from distilabel.steps.filtering.embedding import EmbeddingDedup
+
+if TYPE_CHECKING:
+ from distilabel.steps.typing import StepOutput
+
+
+SAMPLE_DATA = [
+ {
+ "text": "A chemistry student or academic researcher interested in inorganic or physical chemistry, likely at an advanced undergraduate or graduate level, studying acid-base interactions and chemical bonding.",
+ "embedding": [
+ 0.018477669046149742,
+ -0.03748236608841726,
+ 0.001919870620352492,
+ 0.024918478063770535,
+ 0.02348063521315178,
+ 0.0038251285566308375,
+ -0.01723884983037716,
+ 0.02881971942372201,
+ ],
+ },
+ {
+ "text": "A music teacher or instructor focused on theoretical and practical piano lessons.",
+ "embedding": [
+ -0.0023464179614082125,
+ -0.07325472251663565,
+ -0.06058678419516501,
+ -0.02100326928586996,
+ -0.013462744792362657,
+ 0.027368447064244242,
+ -0.003916070100455717,
+ 0.01243614518480423,
+ ],
+ },
+ {
+ "text": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.",
+ "embedding": [
+ -0.01630817942328242,
+ -0.023760151552345232,
+ -0.014249650090627883,
+ -0.005713686451446624,
+ -0.016033059279131567,
+ 0.0071440908501058786,
+ -0.05691099643425161,
+ 0.01597412704817784,
+ ],
+ },
+ {
+ "text": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.",
+ "embedding": [
+ -0.01630817942328242,
+ -0.023760151552345232,
+ -0.014249650090627883,
+ -0.005713686451446624,
+ -0.016033059279131567,
+ 0.0071440908501058786,
+ -0.05691099643425161,
+ 0.01597412704817784,
+ ],
+ },
+]
+
+
+@step(inputs=["embedding"], outputs=["embedding"])
+def NormalizeEmbeddings(inputs: StepInput) -> "StepOutput":
+ # Normalize a vector to have length 1
+ for input in inputs:
+ norm = np.linalg.norm(input["embedding"])
+ if norm == 0:
+ print("Cannot normalize a zero vector")
+ continue
+ input["embedding"] = input["embedding"] / norm
+ yield inputs
+
+
+def test_embedding_deduplication() -> None:
+ with Pipeline() as pipeline:
+ loader = LoadDataFromDicts(
+ data=SAMPLE_DATA * 20,
+ batch_size=50,
+ )
+ batch_size = 50
+
+ # NOTE: Guide to choose an index: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
+ nn = FaissNearestNeighbour(
+ k=3,
+ metric_type=faiss.METRIC_INNER_PRODUCT,
+ search_batch_size=50,
+ # string_factory="IVF300_HNSW32,Flat",
+ # train_size=len(dataset),
+ input_batch_size=batch_size,
+ )
+
+ embedding_dedup = EmbeddingDedup(
+ threshold=0.99,
+ input_batch_size=batch_size,
+ )
+ normalize = NormalizeEmbeddings()
+ loader >> normalize >> nn >> embedding_dedup
+
+ distiset = pipeline.run(use_cache=False)
+
+ ds = distiset["default"]["train"]
+ ds_dedup = ds.filter(lambda x: x["keep_row_after_embedding_filtering"])
+ print(len(ds_dedup))
+ assert len(ds_dedup) == 71
+
+
+if __name__ == "__main__":
+ test_embedding_deduplication()
diff --git a/tests/integration/test_generator_and_sampler.py b/tests/integration/test_generator_and_sampler.py
new file mode 100644
index 0000000000..1bb0a457b5
--- /dev/null
+++ b/tests/integration/test_generator_and_sampler.py
@@ -0,0 +1,55 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.llms._dummy import DummyAsyncLLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import CombineOutputs, LoadDataFromDicts
+from distilabel.steps.generators.data_sampler import DataSampler
+from distilabel.steps.tasks import TextGeneration
+
+
+def get_pipeline():
+ with Pipeline() as pipe:
+ size_dataset_1 = 10
+ loader_1 = LoadDataFromDicts(
+ data=[{"instruction": f"instruction {i}"} for i in range(size_dataset_1)]
+ )
+ sampler = DataSampler(
+ data=[{"sample": f"sample {i}"} for i in range(30)],
+ size=2,
+ samples=size_dataset_1,
+ batch_size=8,
+ )
+ text_generation = TextGeneration(llm=DummyAsyncLLM(), input_batch_size=8)
+
+ combine = CombineOutputs()
+ [loader_1, sampler] >> combine >> text_generation
+ return pipe
+
+
+def test_sampler():
+ pipe = get_pipeline()
+ distiset = pipe.run(use_cache=False)
+ assert len(distiset["default"]["train"]) == 10
+ row = distiset["default"]["train"][0]
+ assert isinstance(row["sample"], list)
+ assert len(row["sample"]) == 2
+ assert isinstance(row["instruction"], str)
+
+
+if __name__ == "__main__":
+ pipe = get_pipeline()
+ distiset = pipe.run(use_cache=False)
+ print(distiset)
+ print(distiset["default"]["train"][0])
diff --git a/tests/integration/test_load_stages.py b/tests/integration/test_load_stages.py
index 2079f32ea1..9faa771d77 100644
--- a/tests/integration/test_load_stages.py
+++ b/tests/integration/test_load_stages.py
@@ -155,10 +155,12 @@ def test_load_stages_status_load_from_cache() -> None:
original_process_batch = pipeline._process_batch
- def _process_batch_wrapper(batch: "_Batch") -> None:
+ def _process_batch_wrapper(
+ batch: "_Batch", send_last_batch_flag: bool = True
+ ) -> None:
if batch.step_name == group_1.name and batch.seq_no == 10:
pipeline._stop_called = True
- original_process_batch(batch)
+ original_process_batch(batch, send_last_batch_flag)
# Run first time and stop the pipeline when specific batch received (simulate CTRL + C)
with mock.patch.object(pipeline, "_process_batch", _process_batch_wrapper):
@@ -167,7 +169,3 @@ def _process_batch_wrapper(batch: "_Batch") -> None:
distiset = pipeline.run(use_cache=True)
assert len(distiset["default"]["train"]) == 1000
-
-
-if __name__ == "__main__":
- test_load_stages_status_load_from_cache()
diff --git a/tests/integration/test_multiple_replicas.py b/tests/integration/test_multiple_replicas.py
index 59950a4374..26d0f19b57 100644
--- a/tests/integration/test_multiple_replicas.py
+++ b/tests/integration/test_multiple_replicas.py
@@ -14,20 +14,17 @@
import random
import time
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING
-from distilabel.pipeline import Pipeline, routing_batch_function
+import pytest
+
+from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts, StepInput, StepResources, step
if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput
-@routing_batch_function()
-def random_routing_batch(steps: List[str]) -> List[str]:
- return random.sample(steps, 2)
-
-
@step(outputs=["generation"])
def Generate(inputs: StepInput) -> "StepOutput":
# random sleep to simulate processing time
@@ -57,6 +54,7 @@ def CombineGenerations(*inputs: StepInput) -> "StepOutput":
yield combined_list
+@pytest.mark.xfail
def test_multiple_replicas() -> None:
with Pipeline(name="test") as pipeline:
load_dataset = LoadDataFromDicts(
diff --git a/tests/integration/test_offline_batch_generation.py b/tests/integration/test_offline_batch_generation.py
new file mode 100644
index 0000000000..a9fe880ff7
--- /dev/null
+++ b/tests/integration/test_offline_batch_generation.py
@@ -0,0 +1,77 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tempfile import TemporaryDirectory
+from typing import TYPE_CHECKING, Any, List, Union
+
+from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
+from distilabel.llms import LLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import LoadDataFromDicts
+from distilabel.steps.tasks import TextGeneration
+
+if TYPE_CHECKING:
+ from distilabel.llms.typing import GenerateOutput
+ from distilabel.steps.tasks.typing import FormattedInput
+
+
+class DummyOfflineBatchGenerateLLM(LLM):
+ def load(self) -> None:
+ super().load()
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ def generate( # type: ignore
+ self, inputs: "FormattedInput", num_generations: int = 1
+ ) -> "GenerateOutput":
+ return ["output" for _ in range(num_generations)]
+
+ def offline_batch_generate(
+ self,
+ inputs: Union[List["FormattedInput"], None] = None,
+ num_generations: int = 1,
+ **kwargs: Any,
+ ) -> List["GenerateOutput"]:
+ # Simulate that the first time we create the jobs
+ if not self.jobs_ids:
+ self.jobs_ids = ("1234", "5678")
+ raise DistilabelOfflineBatchGenerationNotFinishedException(
+ jobs_ids=self.jobs_ids # type: ignore
+ )
+
+ return [["output" for _ in range(num_generations)]]
+
+
+def test_offline_batch_generation() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(cache_dir=tmp_dir) as pipeline:
+ load_data = LoadDataFromDicts(
+ data=[{"instruction": f"{i} instruction"} for i in range(100)]
+ )
+
+ text_generation = TextGeneration(
+ llm=DummyOfflineBatchGenerateLLM(use_offline_batch_generation=True)
+ )
+
+ load_data >> text_generation
+
+ distiset = pipeline.run()
+
+ # First call no results
+ assert len(distiset) == 0
+
+ distiset = pipeline.run(use_cache=True)
+ assert len(distiset) == 1
diff --git a/tests/integration/test_prints.py b/tests/integration/test_prints.py
new file mode 100644
index 0000000000..7db85caf8f
--- /dev/null
+++ b/tests/integration/test_prints.py
@@ -0,0 +1,72 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Union
+
+import pytest
+
+from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+from distilabel.steps import tasks as tasks_
+from tests.unit.conftest import DummyLLM
+
+# The tasks not listed here don't have a print method (or don't have a print method that works)
+tasks = [
+ tasks_.ComplexityScorer,
+ partial(tasks_.EvolInstruct, num_evolutions=1),
+ partial(tasks_.EvolComplexity, num_evolutions=1),
+ partial(tasks_.EvolComplexityGenerator, num_instructions=1),
+ partial(tasks_.EvolInstructGenerator, num_instructions=1),
+ partial(tasks_.EvolQuality, num_evolutions=1),
+ tasks_.Genstruct,
+ partial(
+ tasks_.BitextRetrievalGenerator,
+ source_language="English",
+ target_language="Spanish",
+ unit="sentence",
+ difficulty="elementary school",
+ high_score="4",
+ low_score="2.5",
+ ),
+ partial(tasks_.EmbeddingTaskGenerator, category="text-retrieval"),
+ tasks_.GenerateLongTextMatchingData,
+ tasks_.GenerateShortTextMatchingData,
+ tasks_.GenerateTextClassificationData,
+ tasks_.GenerateTextRetrievalData,
+ tasks_.MonolingualTripletGenerator,
+ tasks_.InstructionBacktranslation,
+ tasks_.Magpie,
+ tasks_.MagpieGenerator,
+ partial(tasks_.PrometheusEval, mode="absolute", rubric="factual-validity"),
+ tasks_.QualityScorer,
+ tasks_.SelfInstruct,
+ partial(tasks_.GenerateSentencePair, action="paraphrase"),
+ tasks_.UltraFeedback,
+ tasks_.URIAL,
+]
+
+
+class TestLLM(DummyLLM, MagpieChatTemplateMixin):
+ magpie_pre_query_template: Union[str, None] = "llama3"
+
+
+llm = TestLLM()
+
+
+@pytest.mark.parametrize("task", tasks)
+def test_prints(task) -> None:
+ t = task(llm=llm)
+ t.load()
+ t.print()
+ t.unload()
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 9884aa76cd..8c7c240b09 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -12,20 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, List
+from typing import TYPE_CHECKING, Any, Dict, List, Union
import pytest
from distilabel.llms.base import LLM, AsyncLLM
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+from distilabel.steps.tasks.base import Task
if TYPE_CHECKING:
from distilabel.llms.typing import GenerateOutput
- from distilabel.steps.tasks.typing import FormattedInput
+ from distilabel.steps.tasks.typing import ChatType, FormattedInput
# Defined here too, so that the serde still works
-class DummyLLM(AsyncLLM):
+class DummyAsyncLLM(AsyncLLM):
+ structured_output: Any = None
+
def load(self) -> None:
pass
@@ -33,12 +36,28 @@ def load(self) -> None:
def model_name(self) -> str:
return "test"
- async def agenerate(
+ async def agenerate( # type: ignore
self, input: "FormattedInput", num_generations: int = 1
) -> "GenerateOutput":
return ["output" for _ in range(num_generations)]
+class DummyLLM(LLM):
+ structured_output: Any = None
+
+ def load(self) -> None:
+ super().load()
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ def generate( # type: ignore
+ self, inputs: "FormattedInput", num_generations: int = 1
+ ) -> "GenerateOutput":
+ return ["output" for _ in range(num_generations)]
+
+
class DummyMagpieLLM(LLM, MagpieChatTemplateMixin):
def load(self) -> None:
pass
@@ -55,6 +74,31 @@ def generate(
]
+class DummyTask(Task):
+ @property
+ def inputs(self) -> List[str]:
+ return ["instruction", "additional_info"]
+
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ return [
+ {"role": "system", "content": ""},
+ {"role": "user", "content": input["instruction"]},
+ ]
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["output", "info_from_input"]
+
+ def format_output(
+ self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ return {"output": output, "info_from_input": input["additional_info"]} # type: ignore
+
+
+class DummyTaskOfflineBatchGeneration(DummyTask):
+ _can_be_used_with_offline_batch_generation = True
+
+
@pytest.fixture
def dummy_llm() -> AsyncLLM:
- return DummyLLM()
+ return DummyAsyncLLM()
diff --git a/tests/unit/embeddings/test_vllm.py b/tests/unit/embeddings/test_vllm.py
new file mode 100644
index 0000000000..8291f434e9
--- /dev/null
+++ b/tests/unit/embeddings/test_vllm.py
@@ -0,0 +1,50 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, Mock
+
+from distilabel.embeddings.vllm import vLLMEmbeddings
+
+
+# @patch("vllm.entrypoints.LLM")
+class TestSentenceTransformersEmbeddings:
+ model_name = "group/model-name"
+
+ def test_model_name(self) -> None:
+ embeddings = vLLMEmbeddings(model=self.model_name)
+
+ assert embeddings.model_name == self.model_name
+
+ def test_encode(self) -> None:
+ embeddings = vLLMEmbeddings(model=self.model_name)
+
+ # the loading should be done here, it's just mocked
+ # embeddings.load()
+ embeddings._model = MagicMock()
+
+ mocked_response = Mock(outputs=Mock(embedding=[0.1] * 10))
+ embeddings._model.encode = Mock(
+ side_effect=lambda x: [mocked_response for _ in range(len(x))]
+ )
+
+ results = embeddings.encode(
+ inputs=[
+ "Hello, how are you?",
+ "What a nice day!",
+ "I hear that llamas are very popular now.",
+ ]
+ )
+
+ for result in results:
+ assert len(result) == 10
diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py
index 0419cc36ce..d820122a4d 100644
--- a/tests/unit/llms/huggingface/test_inference_endpoints.py
+++ b/tests/unit/llms/huggingface/test_inference_endpoints.py
@@ -311,6 +311,9 @@ def test_serialization(self, mock_inference_client: MagicMock) -> None:
"structured_output": None,
"model_display_name": None,
"use_magpie_template": False,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.huggingface.inference_endpoints",
"name": "InferenceEndpointsLLM",
diff --git a/tests/unit/llms/test_anthropic.py b/tests/unit/llms/test_anthropic.py
index 1d7fe44599..11fee764c3 100644
--- a/tests/unit/llms/test_anthropic.py
+++ b/tests/unit/llms/test_anthropic.py
@@ -163,6 +163,9 @@ def test_serialization(
"model": "claude-3-opus-20240229",
"timeout": 600.0,
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.anthropic",
"name": "AnthropicLLM",
diff --git a/tests/unit/llms/test_anyscale.py b/tests/unit/llms/test_anyscale.py
index 73dd3cb6f7..178419c1b7 100644
--- a/tests/unit/llms/test_anyscale.py
+++ b/tests/unit/llms/test_anyscale.py
@@ -49,6 +49,9 @@ def test_serialization(self) -> None:
"base_url": "https://api.endpoints.anyscale.com/v1",
"timeout": 120,
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.anyscale",
"name": "AnyscaleLLM",
diff --git a/tests/unit/llms/test_azure.py b/tests/unit/llms/test_azure.py
index 04b76d5545..eee3ed85fb 100644
--- a/tests/unit/llms/test_azure.py
+++ b/tests/unit/llms/test_azure.py
@@ -74,6 +74,9 @@ def test_azure_openai_llm_env_vars(self) -> None:
"base_url": "https://example-resource.azure.openai.com/",
"timeout": 120,
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.azure",
"name": "AzureOpenAILLM",
@@ -98,6 +101,9 @@ def test_azure_openai_llm_env_vars(self) -> None:
"mode": "tool_call",
"max_retries": 1,
},
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.azure",
"name": "AzureOpenAILLM",
diff --git a/tests/unit/llms/test_base.py b/tests/unit/llms/test_base.py
new file mode 100644
index 0000000000..7c94227753
--- /dev/null
+++ b/tests/unit/llms/test_base.py
@@ -0,0 +1,28 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from distilabel.errors import DistilabelNotImplementedError
+from tests.unit.conftest import DummyLLM
+
+
+class TestLLM:
+ def test_offline_batch_generate_raise_distilabel_not_implemented_error(
+ self,
+ ) -> None:
+ llm = DummyLLM()
+
+ with pytest.raises(DistilabelNotImplementedError):
+ llm.offline_batch_generate()
diff --git a/tests/unit/llms/test_cohere.py b/tests/unit/llms/test_cohere.py
index 371816edf6..2e398e01cf 100644
--- a/tests/unit/llms/test_cohere.py
+++ b/tests/unit/llms/test_cohere.py
@@ -141,6 +141,9 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None:
"timeout": 120,
"client_name": "distilabel",
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.cohere",
"name": "CohereLLM",
@@ -164,6 +167,9 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None:
"mode": "tool_call",
"max_retries": 1,
},
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.cohere",
"name": "CohereLLM",
diff --git a/tests/unit/llms/test_groq.py b/tests/unit/llms/test_groq.py
index c8a782b9a8..f137750292 100644
--- a/tests/unit/llms/test_groq.py
+++ b/tests/unit/llms/test_groq.py
@@ -119,6 +119,9 @@ async def test_generate(self, mock_groq: MagicMock) -> None:
"max_retries": 2,
"timeout": 120,
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.groq",
"name": "GroqLLM",
@@ -142,6 +145,9 @@ async def test_generate(self, mock_groq: MagicMock) -> None:
"mode": "tool_call",
"max_retries": 1,
},
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.groq",
"name": "GroqLLM",
diff --git a/tests/unit/llms/test_litellm.py b/tests/unit/llms/test_litellm.py
index f23722f2fa..56be99e028 100644
--- a/tests/unit/llms/test_litellm.py
+++ b/tests/unit/llms/test_litellm.py
@@ -83,6 +83,9 @@ def test_serialization(self, _: MagicMock, model: str) -> None:
"model": model,
"verbose": False,
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.litellm",
"name": "LiteLLM",
diff --git a/tests/unit/llms/test_llamacpp.py b/tests/unit/llms/test_llamacpp.py
index 280244964a..35c611722d 100644
--- a/tests/unit/llms/test_llamacpp.py
+++ b/tests/unit/llms/test_llamacpp.py
@@ -72,6 +72,9 @@ def test_generate(self, llm: LlamaCppLLM) -> None:
"seed": 4294967295,
"generation_kwargs": {},
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.llamacpp",
"name": "LlamaCppLLM",
@@ -96,6 +99,9 @@ def test_generate(self, llm: LlamaCppLLM) -> None:
"schema": DummyUserDetail.model_json_schema(),
"format": "json",
},
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.llamacpp",
"name": "LlamaCppLLM",
diff --git a/tests/unit/llms/test_mistral.py b/tests/unit/llms/test_mistral.py
index 5bb2337481..f1b7b4b28f 100644
--- a/tests/unit/llms/test_mistral.py
+++ b/tests/unit/llms/test_mistral.py
@@ -97,7 +97,9 @@ async def test_generate(self, mock_mistral: MagicMock) -> None:
mocked_completion = Mock(
choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
)
- llm._aclient.chat = AsyncMock(return_value=mocked_completion)
+ llm._aclient.chat = Mock(
+ complete_async=AsyncMock(return_value=mocked_completion)
+ )
nest_asyncio.apply()
@@ -126,6 +128,9 @@ async def test_generate(self, mock_mistral: MagicMock) -> None:
"timeout": 120,
"max_concurrent_requests": 64,
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.mistral",
"name": "MistralLLM",
@@ -150,6 +155,9 @@ async def test_generate(self, mock_mistral: MagicMock) -> None:
"mode": "tool_call",
"max_retries": 1,
},
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.mistral",
"name": "MistralLLM",
@@ -172,6 +180,9 @@ def test_serialization(
"timeout": 120,
"max_concurrent_requests": 64,
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.mistral",
"name": "MistralLLM",
diff --git a/tests/unit/llms/test_moa.py b/tests/unit/llms/test_moa.py
index 5863012d0e..7efd039b7a 100644
--- a/tests/unit/llms/test_moa.py
+++ b/tests/unit/llms/test_moa.py
@@ -13,22 +13,22 @@
# limitations under the License.
from distilabel.llms.moa import MOA_SYSTEM_PROMPT, MixtureOfAgentsLLM
-from tests.unit.conftest import DummyLLM
+from tests.unit.conftest import DummyAsyncLLM
class TestMixtureOfAgents:
def test_model_name(self) -> None:
llm = MixtureOfAgentsLLM(
- aggregator_llm=DummyLLM(),
- proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()],
+ aggregator_llm=DummyAsyncLLM(),
+ proposers_llms=[DummyAsyncLLM(), DummyAsyncLLM(), DummyAsyncLLM()],
)
assert llm.model_name == "moa-test-test-test-test"
def test_build_moa_system_prompt(self) -> None:
llm = MixtureOfAgentsLLM(
- aggregator_llm=DummyLLM(),
- proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()],
+ aggregator_llm=DummyAsyncLLM(),
+ proposers_llms=[DummyAsyncLLM(), DummyAsyncLLM(), DummyAsyncLLM()],
)
system_prompt = llm._build_moa_system_prompt(
@@ -41,8 +41,8 @@ def test_build_moa_system_prompt(self) -> None:
def test_inject_moa_system_prompt(self) -> None:
llm = MixtureOfAgentsLLM(
- aggregator_llm=DummyLLM(),
- proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()],
+ aggregator_llm=DummyAsyncLLM(),
+ proposers_llms=[DummyAsyncLLM(), DummyAsyncLLM(), DummyAsyncLLM()],
)
results = llm._inject_moa_system_prompt(
diff --git a/tests/unit/llms/test_ollama.py b/tests/unit/llms/test_ollama.py
index f21006fa7b..db31d9cb07 100644
--- a/tests/unit/llms/test_ollama.py
+++ b/tests/unit/llms/test_ollama.py
@@ -82,6 +82,9 @@ def test_serialization(self, _: MagicMock) -> None:
"follow_redirects": True,
"generation_kwargs": {},
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.ollama",
"name": "OllamaLLM",
diff --git a/tests/unit/llms/test_openai.py b/tests/unit/llms/test_openai.py
index a1b09b0883..03fb94c1d3 100644
--- a/tests/unit/llms/test_openai.py
+++ b/tests/unit/llms/test_openai.py
@@ -14,29 +14,38 @@
import os
import sys
+from textwrap import dedent
from typing import Any, Dict
from unittest import mock
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import nest_asyncio
+import orjson
import pytest
+from openai.types import Batch
-from distilabel.llms.openai import OpenAILLM
+from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
+from distilabel.llms.openai import _OPENAI_BATCH_API_MAX_FILE_SIZE, OpenAILLM
from .utils import DummyUserDetail
+@patch("openai.OpenAI")
@patch("openai.AsyncOpenAI")
class TestOpenAILLM:
model_id: str = "gpt-4"
- def test_openai_llm(self, _: MagicMock) -> None:
+ def test_openai_llm(
+ self, _async_openai_mock: MagicMock, _openai_mock: MagicMock
+ ) -> None:
llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
assert isinstance(llm, OpenAILLM)
assert llm.model_name == self.model_id
- def test_openai_llm_env_vars(self, _: MagicMock) -> None:
+ def test_openai_llm_env_vars(
+ self, _async_openai_mock: MagicMock, _openai_mock: MagicMock
+ ) -> None:
with mock.patch.dict(os.environ, clear=True):
os.environ["OPENAI_API_KEY"] = "another.api.key"
os.environ["OPENAI_BASE_URL"] = "https://example.com"
@@ -49,9 +58,11 @@ def test_openai_llm_env_vars(self, _: MagicMock) -> None:
assert llm.api_key.get_secret_value() == "another.api.key" # type: ignore
@pytest.mark.asyncio
- async def test_agenerate(self, mock_openai: MagicMock) -> None:
+ async def test_agenerate(
+ self, async_openai_mock: MagicMock, _openai_mock: MagicMock
+ ) -> None:
llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
- llm._aclient = mock_openai
+ llm._aclient = async_openai_mock
mocked_completion = Mock(
choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
@@ -69,7 +80,9 @@ async def test_agenerate(self, mock_openai: MagicMock) -> None:
)
@pytest.mark.asyncio
- async def test_agenerate_structured(self, mock_openai: MagicMock) -> None:
+ async def test_agenerate_structured(
+ self, async_openai_mock: MagicMock, _openai_mock: MagicMock
+ ) -> None:
llm = OpenAILLM(
model=self.model_id,
api_key="api.key",
@@ -79,7 +92,7 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None:
"max_retries": 1,
},
) # type: ignore
- llm._aclient = mock_openai
+ llm._aclient = async_openai_mock
sample_user = DummyUserDetail(name="John Doe", age=30)
@@ -100,9 +113,11 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None:
sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher"
)
@pytest.mark.asyncio
- async def test_generate(self, mock_openai: MagicMock) -> None:
+ async def test_generate(
+ self, async_openai_mock: MagicMock, _openai_mock: MagicMock
+ ) -> None:
llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
- llm._aclient = mock_openai
+ llm._aclient = async_openai_mock
mocked_completion = Mock(
choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
@@ -137,6 +152,299 @@ async def test_generate(self, mock_openai: MagicMock) -> None:
response_format="unkown_format",
)
+ def test_offline_batch_generate(
+ self, _async_openai_mock: MagicMock, _openai_mock: MagicMock
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
+ llm._create_jobs = mock.MagicMock(return_value=("1234", "5678"))
+
+ with pytest.raises(
+ DistilabelOfflineBatchGenerationNotFinishedException
+ ) as exception_info:
+ llm.offline_batch_generate(
+ inputs=[{"role": "user", "content": "How much is 2+2?"}] # type: ignore
+ )
+
+ assert exception_info.value.jobs_ids == ("1234", "5678")
+
+ def test_offline_batch_generate_with_job_ids(
+ self, _async_openai_mock: MagicMock, _openai_mock: MagicMock
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, api_key="api.key", jobs_ids=("1234",)) # type: ignore
+ llm._check_and_get_batch_results = mock.MagicMock(
+ return_value=[
+ ["output 1"],
+ ["output 2"],
+ ]
+ )
+ assert llm.offline_batch_generate() == [["output 1"], ["output 2"]]
+
+ def test_check_and_get_batch_results(
+ self, async_openai_mock: MagicMock, openai_mock: MagicMock
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, api_key="api.key", jobs_ids=("1234",)) # type: ignore
+ llm._aclient = async_openai_mock
+ llm._client = openai_mock
+ llm._retrieve_batch_results = mock.MagicMock(
+ return_value=[
+ {
+ "custom_id": 2,
+ "response": {
+ "status_code": 200,
+ "body": {
+ "id": "1234",
+ "created": 13,
+ "model": "gpt-4",
+ "object": "chat.completion",
+ "choices": [
+ {
+ "finish_reason": "stop",
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "output 2",
+ },
+ }
+ ],
+ },
+ },
+ },
+ {
+ "custom_id": 1,
+ "response": {
+ "status_code": 200,
+ "body": {
+ "id": "1234",
+ "created": 13,
+ "model": "gpt-4",
+ "object": "chat.completion",
+ "choices": [
+ {
+ "finish_reason": "stop",
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "output 1",
+ },
+ }
+ ],
+ },
+ },
+ },
+ ]
+ )
+ llm.load()
+
+ outputs = llm._check_and_get_batch_results()
+ assert outputs == [["output 1"], ["output 2"]]
+
+ def test_check_and_get_batch_results_raises_valueerror(
+ self, _async_openai_mock: MagicMock, _openai_mock: MagicMock
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
+
+ with pytest.raises(ValueError, match="No job IDs were found"):
+ llm._check_and_get_batch_results()
+
+ @pytest.mark.parametrize("status", ("validating", "in_progress", "finalizing"))
+ def test_check_and_get_batch_results_raises_distilabel_exception(
+ self, async_openai_mock: MagicMock, openai_mock: MagicMock, status: str
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, api_key="api.key", jobs_ids=("1234",)) # type: ignore
+ llm._aclient = async_openai_mock
+ llm._client = openai_mock
+ llm._get_openai_batch = mock.MagicMock(
+ return_value=Batch(
+ id="1234",
+ completion_window="24h",
+ created_at=13,
+ endpoint="/v1/chat/completions",
+ input_file_id="1234",
+ object="batch",
+ status=status, # type: ignore
+ output_file_id="1234",
+ )
+ )
+ llm.load()
+
+ with pytest.raises(DistilabelOfflineBatchGenerationNotFinishedException):
+ llm._check_and_get_batch_results()
+
+ @pytest.mark.parametrize("status", ("failed", "expired", "cancelled", "cancelling"))
+ def test_check_and_get_batch_results_raises_runtimeerror(
+ self, async_openai_mock: MagicMock, openai_mock: MagicMock, status: str
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, api_key="api.key", jobs_ids=("1234",)) # type: ignore
+ llm._aclient = async_openai_mock
+ llm._client = openai_mock
+ llm._get_openai_batch = mock.MagicMock(
+ return_value=Batch(
+ id="1234",
+ completion_window="24h",
+ created_at=13,
+ endpoint="/v1/chat/completions",
+ input_file_id="1234",
+ object="batch",
+ status=status, # type: ignore
+ output_file_id="1234",
+ )
+ )
+ llm.load()
+
+ with pytest.raises(
+ RuntimeError,
+ match=f"The only OpenAI API Batch that was created with ID '1234' failed with status '{status}",
+ ):
+ llm._check_and_get_batch_results()
+
+ def test_parse_output(
+ self, _async_openai_mock: MagicMock, openai_mock: MagicMock
+ ) -> None:
+ pass
+ llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
+
+ result = llm._parse_output(
+ {
+ "response": {
+ "status_code": 200,
+ "body": {
+ "id": "1234",
+ "created": 13,
+ "model": "gpt-4",
+ "object": "chat.completion",
+ "choices": [
+ {
+ "finish_reason": "stop",
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": " Aenean hendrerit aliquam velit. ...",
+ },
+ }
+ ],
+ },
+ }
+ }
+ )
+
+ assert result == [" Aenean hendrerit aliquam velit. ..."]
+
+ def test_retrieve_batch_results(
+ self, _async_openai_mock: MagicMock, openai_mock: MagicMock
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
+ llm._client = openai_mock
+
+ class Response:
+ text: str = dedent(
+ """
+ {"response": {"status_code": 200, "body": {}}}
+ {"response": {"status_code": 200, "body": {}}}
+ {"response": {"status_code": 200, "body": {}}}
+ """.lstrip()
+ )
+
+ llm._client.files.content.return_value = Response()
+
+ results = llm._retrieve_batch_results(
+ batch=Batch(
+ id="1234",
+ completion_window="24h",
+ created_at=13,
+ endpoint="/v1/chat/completions",
+ input_file_id="1234",
+ object="batch",
+ status="completed",
+ output_file_id="1234",
+ )
+ ) # type: ignore
+ assert len(results) == 3
+
+ def test_create_jobs(
+ self, _async_openai_mock: MagicMock, openai_mock: MagicMock
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
+ llm._client = openai_mock
+
+ messages = [
+ {
+ "role": "user",
+ "content": "x" * ((_OPENAI_BATCH_API_MAX_FILE_SIZE // 100) - 50),
+ }
+ ]
+ inputs = [messages] * 150
+
+ jobs = llm._create_jobs(inputs=inputs) # type: ignore
+ assert isinstance(jobs, tuple)
+ assert len(jobs) == 2
+
+ def test_create_batch_files(
+ self, _async_openai_mock: MagicMock, openai_mock: MagicMock
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
+ llm._client = openai_mock
+
+ messages = [
+ {
+ "role": "user",
+ "content": "x" * ((_OPENAI_BATCH_API_MAX_FILE_SIZE // 100) - 50),
+ }
+ ]
+ inputs = [messages] * 150
+
+ files = llm._create_batch_files(inputs=inputs) # type: ignore
+ assert len(files) == 2
+
+ def test_create_jsonl_buffers(
+ self, _async_openai_mock: MagicMock, _openai_mock: MagicMock
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
+
+ # This should be around 1MB
+ messages = [
+ {
+ "role": "user",
+ "content": "x" * ((_OPENAI_BATCH_API_MAX_FILE_SIZE // 100) - 50),
+ }
+ ]
+
+ # Create an input that is larger than the max file size (150MB)
+ inputs = [messages] * 150
+ output = list(llm._create_jsonl_buffers(inputs=inputs)) # type: ignore
+ assert len(output) == 2
+
+ def test_create_jsonl_row(
+ self, _async_openai_mock: MagicMock, _openai_mock: MagicMock
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
+ output = llm._create_jsonl_row(
+ input=[{"role": "user", "content": "How much is 2+2?"}],
+ custom_id="unit-test",
+ **{
+ "model": "gpt-4",
+ "temperature": 0.8,
+ "max_new_tokens": 512,
+ },
+ )
+
+ assert isinstance(output, bytes)
+ assert orjson.loads(output.decode("utf-8")) == {
+ "custom_id": "unit-test",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "messages": [
+ {
+ "role": "user",
+ "content": "How much is 2+2?",
+ }
+ ],
+ "model": "gpt-4",
+ "temperature": 0.8,
+ "max_new_tokens": 512,
+ },
+ }
+
@pytest.mark.parametrize(
"structured_output, dump",
[
@@ -149,6 +457,9 @@ async def test_generate(self, mock_openai: MagicMock) -> None:
"base_url": "https://api.openai.com/v1",
"timeout": 120,
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.openai",
"name": "OpenAILLM",
@@ -172,6 +483,9 @@ async def test_generate(self, mock_openai: MagicMock) -> None:
"mode": "tool_call",
"max_retries": 1,
},
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.openai",
"name": "OpenAILLM",
@@ -181,7 +495,11 @@ async def test_generate(self, mock_openai: MagicMock) -> None:
],
)
def test_serialization(
- self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any]
+ self,
+ _async_openai_mock: MagicMock,
+ _openai_mock: MagicMock,
+ structured_output: Dict[str, Any],
+ dump: Dict[str, Any],
) -> None:
llm = OpenAILLM(model=self.model_id, structured_output=structured_output)
diff --git a/tests/unit/llms/test_together.py b/tests/unit/llms/test_together.py
index d9b50b02d0..409f34866f 100644
--- a/tests/unit/llms/test_together.py
+++ b/tests/unit/llms/test_together.py
@@ -49,6 +49,9 @@ def test_serialization(self) -> None:
"base_url": "https://api.together.xyz/v1",
"timeout": 120,
"structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.together",
"name": "TogetherLLM",
diff --git a/tests/unit/llms/test_vertexai.py b/tests/unit/llms/test_vertexai.py
index 9ad575fb0a..38f5933849 100644
--- a/tests/unit/llms/test_vertexai.py
+++ b/tests/unit/llms/test_vertexai.py
@@ -116,6 +116,9 @@ def test_serialization(self, _: MagicMock) -> None:
_dump = {
"model": "gemini-1.0-pro",
"generation_kwargs": {},
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "distilabel.llms.vertexai",
"name": "VertexAILLM",
diff --git a/tests/unit/llms/test_vllm.py b/tests/unit/llms/test_vllm.py
index c4ba3ddc62..c1df505126 100644
--- a/tests/unit/llms/test_vllm.py
+++ b/tests/unit/llms/test_vllm.py
@@ -194,15 +194,14 @@ def test_prepare_batches_and_sort_back(
@mock.patch("openai.AsyncOpenAI")
class TestClientvLLM:
def test_clientvllm_model_name(
- self, _openai_mock: mock.MagicMock, _async_openai_mock: mock.MagicMock
+ self, _: mock.MagicMock, openai_mock: mock.MagicMock
) -> None:
llm = ClientvLLM(
base_url="http://localhost:8000/v1",
tokenizer="google-bert/bert-base-uncased",
)
- llm.load()
-
+ llm._client = mock.MagicMock()
llm._client.models.list.return_value = SyncPage[Model]( # type: ignore
data=[Model(id="llama", created=1234, object="model", owned_by="")],
object="model",
diff --git a/tests/unit/pipeline/conftest.py b/tests/unit/pipeline/conftest.py
index b3e708a178..a2bf2b932d 100644
--- a/tests/unit/pipeline/conftest.py
+++ b/tests/unit/pipeline/conftest.py
@@ -14,7 +14,10 @@
import pytest
+from distilabel.pipeline._dag import DAG
+from distilabel.pipeline.batch_manager import _BatchManager
from distilabel.pipeline.local import Pipeline
+from distilabel.steps.base import GeneratorStep, GlobalStep, Step
from .utils import DummyGeneratorStep, DummyGlobalStep, DummyStep1, DummyStep2
@@ -42,3 +45,26 @@ def dummy_generator_step_fixture(pipeline: "Pipeline") -> DummyGeneratorStep:
@pytest.fixture(name="dummy_global_step")
def dummy_global_step_fixture(pipeline: "Pipeline") -> DummyGlobalStep:
return DummyGlobalStep(name="dummy_global_step", pipeline=pipeline)
+
+
+@pytest.fixture(name="dummy_dag")
+def dummy_dag_fixture(
+ dummy_generator_step: "GeneratorStep",
+ dummy_step_1: "Step",
+ dummy_step_2: "Step",
+ dummy_global_step: "GlobalStep",
+) -> DAG:
+ dag = DAG()
+ dag.add_step(dummy_generator_step)
+ dag.add_step(dummy_step_1)
+ dag.add_step(dummy_step_2)
+ dag.add_step(dummy_global_step)
+ dag.add_edge("dummy_generator_step", "dummy_step_1")
+ dag.add_edge("dummy_generator_step", "dummy_global_step")
+ dag.add_edge("dummy_step_1", "dummy_step_2")
+ return dag
+
+
+@pytest.fixture(name="dummy_batch_manager")
+def dummy_batch_manager_from_dag_fixture(dummy_dag: DAG) -> _BatchManager:
+ return _BatchManager.from_dag(dummy_dag)
diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py
index 3c0adbe271..86db3f5cfb 100644
--- a/tests/unit/pipeline/test_base.py
+++ b/tests/unit/pipeline/test_base.py
@@ -25,6 +25,11 @@
from pydantic import Field
from upath import UPath
+from distilabel.constants import (
+ INPUT_QUEUE_ATTR_NAME,
+ LAST_BATCH_SENT_FLAG,
+ STEPS_ARTIFACTS_PATH,
+)
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.pipeline.base import (
_STEP_LOAD_FAILED_CODE,
@@ -34,7 +39,6 @@
)
from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.batch_manager import _BatchManager
-from distilabel.pipeline.constants import INPUT_QUEUE_ATTR_NAME, LAST_BATCH_SENT_FLAG
from distilabel.pipeline.routing_batch_function import (
routing_batch_function,
sample_n_steps,
@@ -91,6 +95,28 @@ def test_get_pipeline(self) -> None:
class TestBasePipeline:
+ def test_aggregated_steps_signature(self) -> None:
+ with DummyPipeline(name="dummy") as pipeline_0:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
+
+ generator >> [step, step2] >> step3
+
+ with DummyPipeline(name="dummy") as pipeline_1:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
+
+ generator >> [step, step2] >> step3
+
+ assert (
+ pipeline_0.aggregated_steps_signature
+ == pipeline_1.aggregated_steps_signature
+ )
+
def test_context_manager(self) -> None:
assert _GlobalPipelineManager.get_pipeline() is None
@@ -102,28 +128,35 @@ def test_context_manager(self) -> None:
@pytest.mark.parametrize("use_cache", [False, True])
def test_load_batch_manager(self, use_cache: bool) -> None:
- pipeline = DummyPipeline(name="unit-test-pipeline")
- pipeline._load_batch_manager(use_cache=True)
- pipeline._cache()
-
- with (
- mock.patch(
- "distilabel.pipeline.base._BatchManager.load_from_cache"
- ) as mock_load_from_cache,
- mock.patch(
- "distilabel.pipeline.base._BatchManager.from_dag"
- ) as mock_from_dag,
- ):
- pipeline._load_batch_manager(use_cache=use_cache)
-
- if use_cache:
- mock_load_from_cache.assert_called_once_with(
- pipeline._cache_location["batch_manager"]
- )
- mock_from_dag.assert_not_called()
- else:
- mock_load_from_cache.assert_not_called()
- mock_from_dag.assert_called_once_with(pipeline.dag)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ pipeline = DummyPipeline(name="unit-test-pipeline", cache_dir=temp_dir)
+ pipeline._load_batch_manager(use_cache=True)
+ pipeline._cache()
+
+ with (
+ mock.patch(
+ "distilabel.pipeline.base._BatchManager.load_from_cache"
+ ) as mock_load_from_cache,
+ mock.patch(
+ "distilabel.pipeline.base._BatchManager.from_dag"
+ ) as mock_from_dag,
+ ):
+ pipeline._load_batch_manager(use_cache=use_cache)
+
+ if use_cache:
+ mock_load_from_cache.assert_called_once_with(
+ dag=pipeline.dag,
+ batch_manager_path=pipeline._cache_location["batch_manager"],
+ steps_data_path=pipeline._cache_location["steps_data"],
+ )
+ mock_from_dag.assert_not_called()
+ else:
+ mock_load_from_cache.assert_not_called()
+ mock_from_dag.assert_called_once_with(
+ dag=pipeline.dag,
+ use_cache=use_cache,
+ steps_data_path=pipeline._cache_location["steps_data"],
+ )
def test_setup_write_buffer(self) -> None:
pipeline = DummyPipeline(name="unit-test-pipeline")
@@ -155,6 +188,23 @@ def test_setup_fsspec_raises_value_error(self) -> None:
with pytest.raises(ValueError, match="The 'path' key must be present"):
pipeline._setup_fsspec({"key": "random"})
+ def test_set_pipeline_artifacts_path_in_steps(self) -> None:
+ with DummyPipeline(name="dummy") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
+
+ generator >> [step, step2] >> step3
+
+ pipeline._set_pipeline_artifacts_path_in_steps()
+
+ artifacts_directory = pipeline._cache_location["data"] / STEPS_ARTIFACTS_PATH
+ assert generator.artifacts_directory == artifacts_directory / generator.name # type: ignore
+ assert step.artifacts_directory == artifacts_directory / step.name # type: ignore
+ assert step2.artifacts_directory == artifacts_directory / step2.name # type: ignore
+ assert step3.artifacts_directory == artifacts_directory / step3.name # type: ignore
+
def test_init_steps_load_status(self) -> None:
with DummyPipeline(name="dummy") as pipeline:
generator = DummyGeneratorStep()
@@ -208,6 +258,15 @@ def test_should_continue_processing(self) -> None:
assert not pipeline._should_continue_processing()
+ def test_set_step_for_recovering_offline_batch_generation(self) -> None:
+ with DummyPipeline() as pipeline:
+ step = DummyStep1()
+
+ data = [[{"a": 0}, {"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]]
+ pipeline._set_step_for_recovering_offline_batch_generation(step=step, data=data)
+
+ assert pipeline._recover_offline_batch_generate_for_step == (step.name, data)
+
def test_should_load_next_stage(self) -> None:
with DummyPipeline(name="dummy") as pipeline:
generator = DummyGeneratorStep()
@@ -297,6 +356,7 @@ def test_run_stage_steps_and_wait(self, caplog) -> None:
generator >> [step, step2] >> step3 >> step4
+ pipeline._load_batch_manager()
pipeline._steps_load_status = { # type: ignore
generator.name: 1,
step.name: 1,
@@ -320,6 +380,7 @@ def test_run_stage_steps_and_wait_with_failing_step(self, caplog) -> None:
generator >> [step, step2] >> step3 >> step4
pipeline._init_steps_load_status()
+ pipeline._load_batch_manager()
pipeline._steps_load_status[generator.name] = _STEP_LOAD_FAILED_CODE # type: ignore
caplog.set_level(logging.INFO)
@@ -337,6 +398,7 @@ def test_run_stage_steps_and_wait_stop_called(self) -> None:
generator >> [step, step2] >> step3 >> step4
pipeline._init_steps_load_status()
+ pipeline._load_batch_manager()
pipeline._stop_called = True
assert pipeline._run_stage_steps_and_wait(stage=0) is False
@@ -353,6 +415,7 @@ def test_handle_stop(self) -> None:
pipeline._add_batches_back_to_batch_manager = mock.MagicMock()
pipeline._wait_step_input_queue_empty = mock.MagicMock()
pipeline._consume_output_queue = mock.MagicMock()
+ pipeline._stages_last_batch = [[]]
pipeline._handle_stop()
@@ -594,7 +657,9 @@ def test_register_batch(self) -> None:
batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore
pipeline._register_batch(batch)
- pipeline._batch_manager.register_batch.assert_called_once_with(batch)
+ pipeline._batch_manager.register_batch.assert_called_once_with(
+ batch, steps_data_path=pipeline._cache_location["steps_data"]
+ )
def test_send_last_batch_flag_to_step(self) -> None:
with DummyPipeline(name="unit-test-pipeline") as pipeline:
@@ -711,7 +776,9 @@ def test_handle_batch_on_stop(self) -> None:
batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore
pipeline._handle_batch_on_stop(batch)
- batch_manager_mock.register_batch.assert_called_once_with(batch)
+ batch_manager_mock.register_batch.assert_called_once_with(
+ batch, steps_data_path=pipeline._cache_location["steps_data"]
+ )
batch_manager_mock.add_batch.assert_has_calls(
[
mock.call(step.name, batch),
@@ -1166,13 +1233,16 @@ def test_pipeline_with_dataset_and_generator_step(self):
)
def test_optional_name(self):
- import random
+ from distilabel.pipeline.base import _PIPELINE_DEFAULT_NAME
+
+ assert DummyPipeline().name == _PIPELINE_DEFAULT_NAME
- random.seed(42)
with DummyPipeline() as pipeline:
- name = pipeline.name
- assert name.startswith("pipeline")
- assert len(name.split("_")[-1]) == 8
+ gen_step = DummyGeneratorStep()
+ step1_0 = DummyStep1()
+ gen_step >> step1_0
+
+ assert pipeline.name == "pipeline_dummy_generator_step_0_dummy_step1_0"
class TestPipelineSerialization:
@@ -1265,8 +1335,7 @@ def test_base_pipeline_signature(self) -> None:
pipeline = DummyPipeline(name="unit-test-pipeline")
# Doesn't matter if it's exactly this or not, the test should fail if we change the
# way this is created.
- signature = pipeline._create_signature()
- assert signature == "da39a3ee5e6b4b0d3255bfef95601890afd80709"
+ assert pipeline.signature == "da39a3ee5e6b4b0d3255bfef95601890afd80709"
# Maybe not the best place for this test, but does the work for now
from distilabel.pipeline.local import Pipeline
@@ -1276,11 +1345,11 @@ def test_base_pipeline_signature(self) -> None:
sample_two_steps = sample_n_steps(2)
with Pipeline(name="unit-test-pipeline") as pipeline:
- dummy_generator = DummyGeneratorStep()
- dummy_step_1_0 = DummyStep1()
- dummy_step_1_1 = DummyStep1()
- dummy_step_1_2 = DummyStep1()
- dummy_step_2 = DummyStep2()
+ dummy_generator = DummyGeneratorStep(name="dummy_generator")
+ dummy_step_1_0 = DummyStep1(name="dummy_step_1_0")
+ dummy_step_1_1 = DummyStep1(name="dummy_step_1_1")
+ dummy_step_1_2 = DummyStep1(name="dummy_step_1_2")
+ dummy_step_2 = DummyStep2(name="dummy_step_2")
(
dummy_generator
@@ -1289,8 +1358,68 @@ def test_base_pipeline_signature(self) -> None:
>> dummy_step_2
)
- signature = pipeline._create_signature()
- assert signature == "d3c7c572fe31233aa1198174c6c793b67ef3744b"
+ assert pipeline.signature == "edff8f5bb8b51da406ff274e640f87264f014e3b"
+
+ # attributes shouldn't affect in pipeline signature
+ with Pipeline(name="unit-test-pipeline") as pipeline:
+ dummy_generator = DummyGeneratorStep(name="dummy_generator")
+ dummy_step_1_0 = DummyStep1(name="dummy_step_1_0", attr1=17238497128934)
+ dummy_step_1_1 = DummyStep1(name="dummy_step_1_1")
+ dummy_step_1_2 = DummyStep1(name="dummy_step_1_2")
+ dummy_step_2 = DummyStep2(name="dummy_step_2")
+
+ (
+ dummy_generator
+ >> sample_two_steps
+ >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2]
+ >> dummy_step_2
+ )
+
+ assert pipeline.signature == "edff8f5bb8b51da406ff274e640f87264f014e3b"
+
+ with Pipeline(name="unit-test-pipeline") as pipeline:
+ dummy_generator = DummyGeneratorStep(name="dummy_generator")
+ dummy_step_1_0 = DummyStep1(name="dummy_step_1_0")
+ dummy_step_1_1 = DummyStep1(name="dummy_step_1_1")
+ dummy_step_1_2 = DummyStep1(name="dummy_step_1_2")
+ dummy_step_2 = DummyStep2(name="dummy_step_2")
+
+ (
+ dummy_generator
+ >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2]
+ >> dummy_step_2
+ )
+
+ assert pipeline.signature == "5634172be496319d50848b1679b2a8781cc5581f"
+
+ with Pipeline(name="unit-test-pipeline") as pipeline:
+ dummy_generator = DummyGeneratorStep(name="dummy_generator_second_time")
+ dummy_step_1_0 = DummyStep1(
+ name="dummy_step_1_0_second_time", attr1=17238497128934
+ )
+ dummy_step_1_1 = DummyStep1(name="dummy_step_1_1_second_time")
+ dummy_step_1_2 = DummyStep1(name="dummy_step_1_2_second_time")
+ dummy_step_2 = DummyStep2(name="dummy_step_2_second_time")
+
+ (
+ dummy_generator
+ >> sample_two_steps
+ >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2]
+ >> dummy_step_2
+ )
+
+ assert pipeline.signature == "806dad3fca0f8274af0f374660d4e3eb25d62d12"
+
+ with Pipeline(name="unit-test-pipeline") as pipeline:
+ dummy_generator = DummyGeneratorStep(name="dummy_generator_second_time")
+ dummy_step_1_0 = DummyStep1(
+ name="dummy_step_1_0_second_time", attr1=17238497128934
+ )
+ dummy_step_1_1 = DummyStep1(name="dummy_step_1_1_second_time")
+
+ (dummy_generator >> sample_two_steps >> [dummy_step_1_0, dummy_step_1_1])
+
+ assert pipeline.signature == "7222ce34c677bea3720ef3d08c2673b29b61ff9b"
def test_binary_rshift_operator(self) -> None:
# Tests the steps can be connected using the >> operator.
@@ -1305,7 +1434,7 @@ def test_binary_rshift_operator(self) -> None:
dummy_generator.connect(dummy_step_1)
dummy_step_1.connect(dummy_step_2)
- signature_1 = pipeline_1._create_signature()
+ signature_1 = pipeline_1.signature
with Pipeline(name="unit-test-pipeline-3") as pipeline_2:
dummy_generator = DummyGeneratorStep(name="dummy_generator_step")
@@ -1314,7 +1443,7 @@ def test_binary_rshift_operator(self) -> None:
dummy_generator >> dummy_step_1 >> dummy_step_2
- signature_2 = pipeline_2._create_signature()
+ signature_2 = pipeline_2.signature
assert signature_1 == signature_2
@@ -1331,7 +1460,7 @@ def test_binary_rshift_operator_with_list(self) -> None:
dummy_generator.connect(dummy_step_1)
dummy_generator.connect(dummy_step_2)
- signature_1 = pipeline_1._create_signature()
+ signature_1 = pipeline_1.signature
with Pipeline(name="unit-test-pipeline-2") as pipeline_2:
dummy_generator = DummyGeneratorStep(name="dummy_generator_step")
@@ -1340,7 +1469,7 @@ def test_binary_rshift_operator_with_list(self) -> None:
dummy_generator >> [dummy_step_1, dummy_step_2]
- signature_2 = pipeline_2._create_signature()
+ signature_2 = pipeline_2.signature
assert signature_1 == signature_2
@@ -1360,7 +1489,7 @@ def test_binary_rrshift_operator(self) -> None:
dummy_step_1.connect(dummy_global)
dummy_step_2.connect(dummy_global)
- signature_1 = pipeline_1._create_signature()
+ signature_1 = pipeline_1.signature
with Pipeline(name="unit-test-pipeline-2") as pipeline_2:
dummy_step_1 = DummyStep1(name="dummy_step_1")
@@ -1368,7 +1497,7 @@ def test_binary_rrshift_operator(self) -> None:
dummy_global = DummyGlobalStep(name="dummy_global_step")
[dummy_step_1, dummy_step_2] >> dummy_global
- signature_2 = pipeline_2._create_signature()
+ signature_2 = pipeline_2.signature
assert signature_1 == signature_2
@@ -1394,7 +1523,7 @@ def test_binary_operators(self) -> None:
dummy_step_1.connect(dummy_global)
dummy_step_2.connect(dummy_global)
- signature_1 = pipeline_1._create_signature()
+ signature_1 = pipeline_1.signature
with Pipeline(name="unit-test-pipeline-2") as pipeline_2:
dummy_generator = DummyGeneratorStep(name="dummy_generator_step")
@@ -1403,6 +1532,6 @@ def test_binary_operators(self) -> None:
dummy_global = DummyGlobalStep(name="dummy_global_step")
dummy_generator >> [dummy_step_1, dummy_step_2] >> dummy_global
- signature_2 = pipeline_2._create_signature()
+ signature_2 = pipeline_2.signature
assert signature_1 == signature_2
diff --git a/tests/unit/pipeline/test_batch_manager.py b/tests/unit/pipeline/test_batch_manager.py
index e0e7547305..8801096ce8 100644
--- a/tests/unit/pipeline/test_batch_manager.py
+++ b/tests/unit/pipeline/test_batch_manager.py
@@ -15,14 +15,18 @@
import tempfile
from pathlib import Path
from typing import Dict, List
+from unittest import mock
import pytest
from distilabel.pipeline._dag import DAG
from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.batch_manager import _BatchManager, _BatchManagerStep
+from distilabel.pipeline.local import Pipeline
from distilabel.steps.base import GeneratorStep, GlobalStep, Step
+from .utils import DummyGeneratorStep, DummyStep1, DummyStep2
+
class TestBatchManagerStep:
def test_add_batch(self) -> None:
@@ -144,6 +148,7 @@ def test_get_batch(self) -> None:
)
],
},
+ step_offset={"step1": (0, 0), "step2": (0, 0)},
built_batches=[previously_built_batch],
next_expected_seq_no={"step1": (1, 1), "step2": (1, 1)},
)
@@ -168,7 +173,7 @@ def test_get_batch(self) -> None:
{"b": 2},
],
],
- created_from={"step1": [(1, 5)], "step2": [(1, 5)]},
+ created_from={"step1": [(1, 5, 2)], "step2": [(1, 5, 2)]},
)
batch = batch_manager_step.get_batch()
@@ -187,7 +192,7 @@ def test_get_batch(self) -> None:
{"b": 4},
],
],
- created_from={"step1": [(1, 5)], "step2": [(1, 5)]},
+ created_from={"step1": [(1, 5, 2)], "step2": [(1, 5, 2)]},
)
def test_get_batches_accumulate(self) -> None:
@@ -231,6 +236,7 @@ def test_get_batches_accumulate(self) -> None:
)
],
},
+ step_offset={"step1": (0, 0), "step2": (0, 0)},
last_batch_received=["step1", "step2"],
)
@@ -258,7 +264,7 @@ def test_get_batches_accumulate(self) -> None:
{"b": 6},
],
],
- created_from={"step1": [(0, 5)], "step2": [(0, 6)]},
+ created_from={"step1": [(0, 5, 5)], "step2": [(0, 6, 6)]},
)
def test_get_batches_not_enough_data(self) -> None:
@@ -430,7 +436,7 @@ def test_get_data(self) -> None:
[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}],
[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}],
]
- assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]}
+ assert created_from == {"step1": [(0, 6, 5)], "step2": [(0, 7, 5)]}
assert routed_to == ["step1", "step2"]
assert batch_manager_step.data == {
@@ -502,7 +508,7 @@ def test_get_data_accumulate(self) -> None:
[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}],
[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}],
]
- assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]}
+ assert created_from == {"step1": [(0, 6, 6)], "step2": [(0, 7, 7)]}
assert routed_to == []
assert batch_manager_step.data == {"step1": [], "step2": []}
@@ -520,7 +526,7 @@ def test_get_data_convergence_step(self) -> None:
]
],
size=3,
- created_from={"Z": [(0, 3)]},
+ created_from={"Z": [(0, 3, 3)]},
)
batch_a_1 = _Batch(
@@ -535,7 +541,7 @@ def test_get_data_convergence_step(self) -> None:
]
],
size=3,
- created_from={"Z": [(1, 3)]},
+ created_from={"Z": [(1, 3, 3)]},
)
batch_b_0 = _Batch(
@@ -550,7 +556,7 @@ def test_get_data_convergence_step(self) -> None:
]
],
size=3,
- created_from={"Z": [(0, 3)]},
+ created_from={"Z": [(0, 3, 3)]},
)
batch_c_0 = _Batch(
@@ -565,7 +571,7 @@ def test_get_data_convergence_step(self) -> None:
]
],
size=3,
- created_from={"Z": [(1, 3)]},
+ created_from={"Z": [(1, 3, 3)]},
)
batch_manager_step = _BatchManagerStep(
@@ -590,7 +596,7 @@ def test_get_data_convergence_step(self) -> None:
{"generation": "Hello, I'm B 0"},
],
]
- assert created_from == {"A": [(0, 3)], "B": [(0, 3)]}
+ assert created_from == {"A": [(0, 3, 3)], "B": [(0, 3, 3)]}
assert routed_to == []
assert batch_manager_step.next_expected_created_from_batch_seq_no == 1
@@ -608,7 +614,7 @@ def test_get_data_convergence_step(self) -> None:
{"generation": "Hello, I'm C 0"},
],
]
- assert created_from == {"A": [(1, 3)], "C": [(0, 3)]}
+ assert created_from == {"A": [(1, 3, 3)], "C": [(0, 3, 3)]}
assert routed_to == []
assert batch_manager_step.next_expected_created_from_batch_seq_no == 2
@@ -803,7 +809,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
"step2": [
@@ -812,7 +818,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=False,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
},
@@ -827,7 +833,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=True,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
"step2": [
@@ -836,7 +842,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=True,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
},
@@ -851,7 +857,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=True,
data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
- created_from={"step0": [(0, 3)]},
+ created_from={"step0": [(0, 3, 3)]},
)
],
"step2": [
@@ -860,7 +866,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=True,
data=[[{"b": 1}, {"b": 2}, {"b": 3}]],
- created_from={"step0": [(0, 3)]},
+ created_from={"step0": [(0, 3, 3)]},
)
],
},
@@ -1217,6 +1223,9 @@ def test_dump(self) -> None:
"step1": (0, 0),
"step2": (0, 0),
},
+ "step_offset": {},
+ "step_signature": None,
+ "use_cache": False,
"type_info": {
"module": "distilabel.pipeline.batch_manager",
"name": "_BatchManagerStep",
@@ -1235,7 +1244,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
"step2": [],
@@ -1252,7 +1261,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
"step2": [
@@ -1262,7 +1271,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
},
@@ -1278,7 +1287,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
+ created_from={"step0": [(0, 4, 4)]},
)
],
"step2": [
@@ -1288,7 +1297,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
},
@@ -1304,7 +1313,7 @@ def test_dump(self) -> None:
last_batch=True,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
+ created_from={"step0": [(0, 4, 4)]},
)
],
"step2": [
@@ -1314,7 +1323,7 @@ def test_dump(self) -> None:
last_batch=True,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
+ created_from={"step0": [(0, 4, 4)]},
)
],
},
@@ -1330,7 +1339,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
+ created_from={"step0": [(0, 4, 4)]},
)
],
"step2": [
@@ -1340,7 +1349,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
},
@@ -1467,6 +1476,41 @@ def test_add_batch(self) -> None:
"step2": [],
}
+ def test_step_hash_finished(self) -> None:
+ batch_manager = _BatchManager(
+ steps={
+ "step1": _BatchManagerStep(
+ step_name="step1",
+ accumulate=False,
+ input_batch_size=5,
+ data={},
+ ),
+ "step2": _BatchManagerStep(
+ step_name="step2",
+ accumulate=False,
+ input_batch_size=5,
+ data={"step_1": []},
+ ),
+ "step3": _BatchManagerStep(
+ step_name="step3",
+ accumulate=False,
+ input_batch_size=5,
+ data={"step2": []},
+ ),
+ },
+ last_batch_received={
+ "step1": _Batch(seq_no=0, step_name="step1", last_batch=True),
+ "step2": None,
+ "step3": None,
+ },
+ last_batch_sent={"step1": None, "step2": None, "step3": None},
+ last_batch_flag_sent_to=["step2"],
+ )
+
+ assert batch_manager.step_has_finished("step1") is True
+ assert batch_manager.step_has_finished("step2") is True
+ assert batch_manager.step_has_finished("step3") is False
+
def test_add_batch_with_prepend(self) -> None:
batch_1 = _Batch(
seq_no=1,
@@ -1503,6 +1547,38 @@ def test_add_batch_with_prepend(self) -> None:
"step2": [],
}
+ def test_add_batch_to_recover_offline_batch_generation(self) -> None:
+ batch_manager = _BatchManager(
+ steps={
+ "step1": _BatchManagerStep(
+ step_name="step0",
+ accumulate=True,
+ input_batch_size=5,
+ data={},
+ )
+ },
+ last_batch_received={
+ "step1": _Batch(seq_no=0, step_name="step1", last_batch=True)
+ },
+ last_batch_sent={"step1": None},
+ last_batch_flag_sent_to=[],
+ )
+
+ batch_manager.add_batch_to_recover_offline_batch_generation(
+ to_step="step1",
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+
+ assert batch_manager._steps["step1"].built_batches == [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ ]
+ assert batch_manager._last_batch_received["step1"] is None
+
def test_from_dag(
self,
dummy_generator_step: "GeneratorStep",
@@ -1522,12 +1598,26 @@ def test_from_dag(
batch_manager = _BatchManager.from_dag(dag)
assert batch_manager._steps == {
+ "dummy_generator_step": _BatchManagerStep(
+ step_name="dummy_generator_step",
+ accumulate=False,
+ input_batch_size=None,
+ data={},
+ convergence_step=True,
+ next_expected_seq_no={},
+ step_signature="963a16b6081170f39eef011d64d992a0a6e9f0e9",
+ use_cache=True,
+ step_offset={},
+ ),
"dummy_step_1": _BatchManagerStep(
step_name="dummy_step_1",
accumulate=False,
input_batch_size=50,
data={"dummy_generator_step": []},
next_expected_seq_no={"dummy_generator_step": (0, 0)},
+ step_signature="bc765d5801dc71c88a1a444e1b1e26035d309724",
+ use_cache=True,
+ step_offset={"dummy_generator_step": (0, 0)},
),
"dummy_global_step": _BatchManagerStep(
step_name="dummy_global_step",
@@ -1535,6 +1625,9 @@ def test_from_dag(
input_batch_size=50,
data={"dummy_generator_step": []},
next_expected_seq_no={"dummy_generator_step": (0, 0)},
+ step_signature="6a0e9f45043fa7dc37e2b36269d660dfef63dbb7",
+ use_cache=True,
+ step_offset={"dummy_generator_step": (0, 0)},
),
"dummy_step_2": _BatchManagerStep(
step_name="dummy_step_2",
@@ -1542,9 +1635,73 @@ def test_from_dag(
input_batch_size=50,
data={"dummy_step_1": []},
next_expected_seq_no={"dummy_step_1": (0, 0)},
+ step_signature="2d1076164acb43431aad1a54a781b7bad22c7037",
+ use_cache=True,
+ step_offset={"dummy_step_1": (0, 0)},
),
}
+ def test_cache(self, dummy_batch_manager: _BatchManager) -> None:
+ # We test the cache starting from the DAG because we need the signature
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ batch_manager_path = Path(tmp_dir) / "batch_manager.json"
+ dummy_batch_manager.cache(batch_manager_path, Path(tmp_dir))
+
+ assert batch_manager_path.exists() and batch_manager_path.is_file()
+
+ for step_name, step in dummy_batch_manager._steps.items():
+ batch_manager_step_dir = (
+ Path(tmp_dir) / "batch_manager_steps" / step_name
+ )
+ assert (
+ batch_manager_step_dir.exists() and batch_manager_step_dir.is_dir()
+ )
+
+ batch_manager_step_path = (
+ batch_manager_step_dir / "batch_manager_step.json"
+ )
+ assert (
+ batch_manager_step_path.exists()
+ and batch_manager_step_path.is_file()
+ )
+
+ built_batches_dir = batch_manager_step_dir / "built_batches"
+ assert built_batches_dir.exists()
+
+ for batch in step.built_batches:
+ batch_path = (
+ built_batches_dir
+ / f"batch_{batch.seq_no}_{batch.data_hash}.json"
+ )
+ assert batch_path.exists() and batch_path.is_file()
+
+ # for buffered_step_name in step.data:
+ # buffered_step_dir = batch_manager_step_dir / buffered_step_name
+ # assert buffered_step_dir.exists() and buffered_step_dir.is_dir()
+
+ # for batch in step.data[buffered_step_name]:
+ # batch_path = (
+ # buffered_step_dir
+ # / f"batch_{batch.seq_no}_{batch.data_hash}.json"
+ # )
+ # assert batch_path.exists() and batch_path.is_file()
+
+ def test_load_from_cache(
+ self, dummy_dag: DAG, dummy_batch_manager: _BatchManager
+ ) -> None:
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ from pathlib import Path
+
+ tmp_dir = Path.home() / "Downloads/test_batch_manager"
+
+ batch_manager_path = Path(tmp_dir) / "batch_manager.json"
+ dummy_batch_manager.cache(batch_manager_path, Path(tmp_dir))
+ loaded_batch_manager = _BatchManager.load_from_cache(
+ dummy_dag, batch_manager_path, Path(tmp_dir)
+ )
+
+ assert dummy_batch_manager.dump() == loaded_batch_manager.dump()
+
def test_can_generate(self) -> None:
batch_manager = _BatchManager(
steps={},
@@ -1576,6 +1733,108 @@ def test_can_generate(self) -> None:
assert not batch_manager.can_generate()
+ def test_invalidate_cache_for(self) -> None:
+ with Pipeline() as pipeline:
+ generator = DummyGeneratorStep()
+ step_a = DummyStep1()
+ step_b = DummyStep1()
+ step_c = DummyStep2()
+
+ generator >> [step_a, step_b] >> step_c
+
+ pipeline._load_batch_manager()
+ batch_manager: "_BatchManager" = pipeline._batch_manager # type: ignore
+
+ with (
+ mock.patch.object(
+ batch_manager, "_reset_batch_manager_for_step"
+ ) as reset_mock,
+ mock.patch.object(batch_manager, "_load_predecessor_batches") as load_mock,
+ ):
+ batch_manager.invalidate_cache_for(
+ step_name=step_a.name, # type: ignore
+ dag=pipeline.dag,
+ steps_data_path=pipeline._cache_location["steps_data"],
+ )
+
+ # shouldn't have been called for step b
+ reset_mock.assert_has_calls(
+ [
+ mock.call(step_a.name, pipeline.dag),
+ mock.call(step_c.name, pipeline.dag),
+ ]
+ )
+
+ load_mock.assert_called_once_with(
+ step_a.name, pipeline.dag, pipeline._cache_location["steps_data"]
+ )
+
+ def test_reset_batch_manager_for_step(self) -> None:
+ batch_manager = _BatchManager(
+ steps={
+ "step1": _BatchManagerStep(
+ step_name="step1",
+ accumulate=True,
+ input_batch_size=5,
+ data={
+ "step0": [_Batch(seq_no=0, step_name="step0", last_batch=True)]
+ },
+ )
+ },
+ last_batch_received={
+ "step1": _Batch(seq_no=0, step_name="step1", last_batch=True)
+ },
+ last_batch_sent={
+ "step1": _Batch(seq_no=0, step_name="step1", last_batch=True)
+ },
+ last_batch_flag_sent_to=["step1"],
+ )
+
+ dag = DAG()
+ dag.add_step(DummyStep1(name="step1"))
+
+ batch_manager._reset_batch_manager_for_step("step1", dag)
+ assert batch_manager._steps["step1"].data == {}
+ assert batch_manager._last_batch_received["step1"] is None
+ assert batch_manager._last_batch_sent["step1"] is None
+ assert batch_manager._last_batch_flag_sent_to == []
+
+ def test_load_predecessor_batches(self) -> None:
+ with Pipeline() as pipeline:
+ generator = DummyGeneratorStep()
+ step_a = DummyStep1()
+ step_b = DummyStep1()
+ step_c = DummyStep2()
+
+ generator >> [step_a, step_b] >> step_c
+
+ pipeline._load_batch_manager()
+ batch_manager: "_BatchManager" = pipeline._batch_manager # type: ignore
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ previous_step_dir = (
+ Path(tmp_dir) / f"{generator.name}_{generator.signature}"
+ ) # type: ignore
+ batches = []
+ for i in range(3):
+ batch = _Batch(
+ seq_no=i,
+ step_name=generator.name, # type: ignore
+ data=[[{"a": i} for _ in range(5)]],
+ last_batch=i % 3 == 0,
+ )
+ batches.append(batch)
+ batch.save(path=previous_step_dir / f"batch_{i}.json")
+
+ batch_manager._load_predecessor_batches(
+ step_name=step_a.name, # type: ignore
+ dag=pipeline.dag,
+ steps_data_path=Path(tmp_dir), # type: ignore
+ )
+
+ assert batch_manager._steps[step_a.name].data[generator.name] == batches # type: ignore
+ assert generator.name in batch_manager._steps[step_a.name].last_batch_received # type: ignore
+
def test_dump(self) -> None:
built_batch = _Batch(
seq_no=0,
@@ -1649,6 +1908,9 @@ def test_dump(self) -> None:
"step1": (1, 1),
"step2": (1, 1),
},
+ "step_offset": {},
+ "step_signature": None,
+ "use_cache": False,
"type_info": {
"module": "distilabel.pipeline.batch_manager",
"name": "_BatchManagerStep",
@@ -1866,467 +2128,3 @@ def test_from_dict(self) -> None:
assert isinstance(step, _Batch)
assert batch_manager._last_batch_flag_sent_to == ["step3"]
-
- def test_cache(self) -> None:
- batch_manager = _BatchManager.from_dict(
- {
- "steps": {
- "step1": {
- "step_name": "step1",
- "accumulate": True,
- "convergence_step": False,
- "convergence_step_batches_consumed": {"0": {"Z": 1234}},
- "input_batch_size": None,
- "data": {
- "step2": [
- {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": True,
- "data": [
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- "data_hash": "1234",
- "size": 7,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- }
- ],
- },
- "built_batches": [
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- ]
- ],
- "data_hash": "1234",
- "size": 5,
- "accumulated": False,
- "batch_routed_to": [],
- "created_from": {},
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- "seq_no": 0,
- "last_batch_received": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManagerStep",
- },
- },
- "step2": {
- "step_name": "step2",
- "accumulate": False,
- "convergence_step": False,
- "convergence_step_batches_consumed": {"0": {"Z": 1234}},
- "input_batch_size": 50,
- "data": {
- "step2": [
- {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": True,
- "data": [
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- "data_hash": "1234",
- "size": 7,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- }
- ],
- },
- "built_batches": [
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- ]
- ],
- "data_hash": "1234",
- "size": 5,
- "accumulated": False,
- "batch_routed_to": [],
- "created_from": {},
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- "seq_no": 0,
- "last_batch_received": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManagerStep",
- },
- },
- },
- "last_batch_received": {
- "step1": {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- },
- "step2": {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- },
- },
- "last_batch_sent": {
- "step1": {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- },
- "step2": {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- },
- },
- "last_batch_flag_sent_to": ["step3"],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManager",
- },
- }
- )
-
- with tempfile.TemporaryDirectory() as tmp_dir:
- batch_manager_path = Path(tmp_dir) / "batch_manager.json"
- batch_manager.cache(batch_manager_path)
-
- assert batch_manager_path.exists() and batch_manager_path.is_file()
-
- for step_name, step in batch_manager._steps.items():
- batch_manager_step_dir = (
- Path(tmp_dir) / "batch_manager_steps" / step_name
- )
- assert (
- batch_manager_step_dir.exists() and batch_manager_step_dir.is_dir()
- )
-
- batch_manager_step_path = (
- batch_manager_step_dir / "batch_manager_step.json"
- )
- assert (
- batch_manager_step_path.exists()
- and batch_manager_step_path.is_file()
- )
-
- built_batches_dir = batch_manager_step_dir / "built_batches"
- assert built_batches_dir.exists()
-
- for batch in step.built_batches:
- batch_path = (
- built_batches_dir
- / f"batch_{batch.seq_no}_{batch.data_hash}.json"
- )
- assert batch_path.exists() and batch_path.is_file()
-
- for buffered_step_name in step.data:
- buffered_step_dir = batch_manager_step_dir / buffered_step_name
- assert buffered_step_dir.exists() and buffered_step_dir.is_dir()
-
- for batch in step.data[buffered_step_name]:
- batch_path = (
- buffered_step_dir
- / f"batch_{batch.seq_no}_{batch.data_hash}.json"
- )
- assert batch_path.exists() and batch_path.is_file()
-
- def test_load_from_cache(self) -> None:
- batch_manager = _BatchManager.from_dict(
- {
- "steps": {
- "step1": {
- "step_name": "step1",
- "accumulate": True,
- "convergence_step": False,
- "convergence_step_batches_consumed": {"0": {"Z": 1234}},
- "input_batch_size": None,
- "data": {
- "step2": [
- {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": True,
- "data": [
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- "data_hash": "1234",
- "size": 7,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- },
- "built_batches": [
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- ]
- ],
- "data_hash": "1234",
- "size": 5,
- "accumulated": False,
- "batch_routed_to": [],
- "created_from": {},
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- "seq_no": 0,
- "last_batch_received": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManagerStep",
- },
- },
- "step2": {
- "step_name": "step2",
- "accumulate": False,
- "convergence_step": False,
- "convergence_step_batches_consumed": {"0": {"Z": 1234}},
- "input_batch_size": 50,
- "data": {
- "step2": [
- {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": True,
- "data": [
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- "data_hash": "1234",
- "size": 7,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- },
- "built_batches": [
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- ]
- ],
- "data_hash": "1234",
- "size": 5,
- "accumulated": False,
- "batch_routed_to": [],
- "created_from": {},
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- "seq_no": 0,
- "last_batch_received": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManagerStep",
- },
- },
- },
- "last_batch_received": {
- "step1": {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- },
- "step2": {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- },
- },
- "last_batch_sent": {
- "step1": {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- },
- "step2": {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- },
- },
- "last_batch_flag_sent_to": ["step3"],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManager",
- },
- }
- )
-
- with tempfile.TemporaryDirectory() as tmp_dir:
- batch_manager_path = Path(tmp_dir) / "batch_manager.json"
- batch_manager.cache(batch_manager_path)
- loaded_batch_manager = _BatchManager.load_from_cache(batch_manager_path)
-
- assert batch_manager.dump() == loaded_batch_manager.dump()
diff --git a/tests/unit/pipeline/test_dag.py b/tests/unit/pipeline/test_dag.py
index 51c5aa2f46..a5b55520f4 100644
--- a/tests/unit/pipeline/test_dag.py
+++ b/tests/unit/pipeline/test_dag.py
@@ -18,9 +18,9 @@
import pytest
+from distilabel.constants import STEP_ATTR_NAME
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.pipeline._dag import DAG
-from distilabel.pipeline.constants import STEP_ATTR_NAME
from distilabel.pipeline.local import Pipeline
from distilabel.pipeline.routing_batch_function import routing_batch_function
from distilabel.steps.base import GeneratorStep, Step, StepInput, StepResources
@@ -33,6 +33,11 @@
StepOutput,
)
+import base64
+from unittest.mock import MagicMock, patch
+
+import requests
+
class TestDAG:
def test_add_step(self, dummy_step_1: "Step") -> None:
@@ -276,6 +281,37 @@ def test_get_steps_load_stages(self) -> None:
],
)
+ def test_get_steps_load_stages_global_steps_chained(self) -> None:
+ with Pipeline(name="dummy") as pipeline:
+ generator = DummyGeneratorStep(name="dummy_generator_step")
+ dummies_0 = [DummyStep1(name=f"dummy_step_0_{i}") for i in range(3)]
+ global_0 = DummyGlobalStep(name="global_0")
+ global_1 = DummyGlobalStep(name="global_1")
+
+ generator >> dummies_0 >> global_0 >> global_1
+
+ assert pipeline.dag.get_steps_load_stages() == (
+ [
+ [
+ "dummy_generator_step",
+ "dummy_step_0_0",
+ "dummy_step_0_1",
+ "dummy_step_0_2",
+ ],
+ ["global_0"],
+ ["global_1"],
+ ],
+ [
+ [
+ "dummy_step_0_0",
+ "dummy_step_0_1",
+ "dummy_step_0_2",
+ ],
+ ["global_0"],
+ ["global_1"],
+ ],
+ )
+
def test_get_steps_load_stages_simple(self) -> None:
with Pipeline(name="dummy") as pipeline:
generator = DummyGeneratorStep(name="dummy_generator_step")
@@ -776,3 +812,126 @@ def test_dag_to_from_file_format(
with Pipeline(name="unit-test-pipeline"):
dag_from_file = loader(filename)
assert isinstance(dag_from_file, DAG)
+
+
+class TestDAGDraw:
+ @patch("distilabel.pipeline._dag.requests.get")
+ def test_draw_basic(self, mock_get):
+ # Mock the response from mermaid.ink
+ mock_response = MagicMock()
+ mock_response.content = b"mocked_image_content"
+ mock_get.return_value = mock_response
+
+ dag = DAG()
+ generator_step = DummyGeneratorStep(name="generator")
+ step1 = DummyStep1(name="step1")
+ step2 = DummyStep2(name="step2")
+
+ dag.add_step(generator_step)
+ dag.add_step(step1)
+ dag.add_step(step2)
+ dag.add_edge("generator", "step1")
+ dag.add_edge("step1", "step2")
+
+ image_content = dag.draw()
+
+ assert image_content == b"mocked_image_content"
+ mock_get.assert_called_once()
+ called_url = mock_get.call_args[0][0]
+ assert "https://mermaid.ink/img/" in called_url
+
+ @patch("distilabel.pipeline._dag.requests.get")
+ def test_draw_top_to_bottom(self, mock_get):
+ mock_response = MagicMock()
+ mock_response.content = b"mocked_image_content"
+ mock_get.return_value = mock_response
+
+ dag = DAG()
+ generator_step = DummyGeneratorStep(name="generator")
+ step1 = DummyStep1(name="step1")
+ dag.add_step(generator_step)
+ dag.add_step(step1)
+ dag.add_edge("generator", "step1")
+
+ dag.draw(top_to_bottom=True)
+
+ called_url = mock_get.call_args[0][0]
+ decoded_graph = base64.b64decode(
+ called_url.split("/")[-1].split("?")[0]
+ ).decode("ascii")
+ assert "flowchart TD" in decoded_graph
+
+ @patch("distilabel.pipeline._dag.requests.get")
+ def test_draw_without_edge_labels(self, mock_get):
+ mock_response = MagicMock()
+ mock_response.content = b"mocked_image_content"
+ mock_get.return_value = mock_response
+
+ dag = DAG()
+ generator_step = DummyGeneratorStep(name="generator")
+ step1 = DummyStep1(name="step1")
+ dag.add_step(generator_step)
+ dag.add_step(step1)
+ dag.add_edge("generator", "step1")
+
+ dag.draw(show_edge_labels=False)
+
+ called_url = mock_get.call_args[0][0]
+ decoded_graph = base64.b64decode(
+ called_url.split("/")[-1].split("?")[0]
+ ).decode("ascii")
+ assert "generator --> step1" in decoded_graph
+ assert "|" not in decoded_graph # No edge labels
+
+ @patch("distilabel.pipeline._dag.requests.get")
+ def test_draw_with_argilla_step(self, mock_get):
+ mock_response = MagicMock()
+ mock_response.content = b"mocked_image_content"
+ mock_get.return_value = mock_response
+
+ dag = DAG()
+ generator_step = DummyGeneratorStep(name="generator")
+ step1 = DummyStep1(name="to_argilla")
+ dag.add_step(generator_step)
+ dag.add_step(step1)
+ dag.add_edge("generator", "to_argilla")
+
+ dag.draw()
+
+ called_url = mock_get.call_args[0][0]
+ decoded_graph = base64.b64decode(
+ called_url.split("/")[-1].split("?")[0]
+ ).decode("ascii")
+ assert 'to_argilla_0["Argilla"]' in decoded_graph
+
+ @patch("distilabel.pipeline._dag.requests.get")
+ def test_draw_with_distiset_step(self, mock_get):
+ mock_response = MagicMock()
+ mock_response.content = b"mocked_image_content"
+ mock_get.return_value = mock_response
+
+ dag = DAG()
+ generator_step = DummyGeneratorStep(name="generator")
+ step1 = DummyStep1(name="step1")
+ dag.add_step(generator_step)
+ dag.add_step(step1)
+ dag.add_edge("generator", "step1")
+
+ dag.draw()
+
+ called_url = mock_get.call_args[0][0]
+ decoded_graph = base64.b64decode(
+ called_url.split("/")[-1].split("?")[0]
+ ).decode("ascii")
+ assert 'distiset_0["Distiset"]' in decoded_graph
+
+ @patch("distilabel.pipeline._dag.requests.get")
+ def test_draw_error_handling(self, mock_get):
+ mock_get.side_effect = requests.RequestException("Mocked error")
+
+ dag = DAG()
+ generator_step = DummyGeneratorStep(name="generator")
+ dag.add_step(generator_step)
+
+ with pytest.raises(ValueError, match="Error accessing https://mermaid.ink/"):
+ dag.draw()
diff --git a/tests/unit/pipeline/test_ray.py b/tests/unit/pipeline/test_ray.py
index 127f9eceb4..610f272196 100644
--- a/tests/unit/pipeline/test_ray.py
+++ b/tests/unit/pipeline/test_ray.py
@@ -16,11 +16,13 @@
import pytest
+from distilabel.errors import DistilabelUserError
from distilabel.llms.vllm import vLLM
from distilabel.pipeline.ray import RayPipeline
from distilabel.steps.base import StepResources
from distilabel.steps.tasks.text_generation import TextGeneration
from distilabel.utils.serialization import TYPE_INFO_KEY
+from tests.unit.conftest import DummyAsyncLLM, DummyTaskOfflineBatchGeneration
@pytest.fixture
@@ -56,6 +58,18 @@ def test_dump(self) -> None:
"name": "Pipeline",
}
+ def test_check_no_llms_using_offline_batch_generation(self) -> None:
+ with RayPipeline(name="unit-test") as pipeline:
+ DummyTaskOfflineBatchGeneration(
+ name="unit-test", llm=DummyAsyncLLM(use_offline_batch_generation=True)
+ )
+
+ with pytest.raises(
+ DistilabelUserError,
+ match="Step 'unit-test' uses an `LLM` with offline batch generation",
+ ):
+ pipeline._check_no_llms_using_offline_batch_generation()
+
def test_get_ray_gpus_per_node(self) -> None:
pipeline = RayPipeline(name="unit-test")
pipeline._init_ray()
diff --git a/tests/unit/pipeline/test_write_buffer.py b/tests/unit/pipeline/test_write_buffer.py
index 7c638550aa..cbd717c99e 100644
--- a/tests/unit/pipeline/test_write_buffer.py
+++ b/tests/unit/pipeline/test_write_buffer.py
@@ -15,6 +15,7 @@
import tempfile
from pathlib import Path
+from distilabel.constants import STEPS_OUTPUTS_PATH
from distilabel.distiset import Distiset, create_distiset
from distilabel.pipeline.local import Pipeline
from distilabel.pipeline.write_buffer import _WriteBuffer
@@ -29,7 +30,8 @@
class TestWriteBuffer:
def test_create(self) -> None:
with tempfile.TemporaryDirectory() as tmpdirname:
- folder = Path(tmpdirname) / "data"
+ folder = Path(tmpdirname) / "data" / STEPS_OUTPUTS_PATH
+ steps_outputs = folder / STEPS_OUTPUTS_PATH
with Pipeline(name="unit-test-pipeline") as pipeline:
dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1")
dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2")
@@ -42,7 +44,9 @@ def test_create(self) -> None:
dummy_step_1.connect(dummy_step_2)
dummy_step_1.connect(dummy_step_3)
- write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps)
+ write_buffer = _WriteBuffer(
+ path=steps_outputs, leaf_steps=pipeline.dag.leaf_steps
+ )
assert write_buffer._buffers == {"dummy_step_2": [], "dummy_step_3": []}
assert write_buffer._buffers_dump_batch_size == {
@@ -58,6 +62,7 @@ def test_create(self) -> None:
def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None:
with tempfile.TemporaryDirectory() as tmpdirname:
folder = Path(tmpdirname) / "data"
+ steps_outputs = folder / STEPS_OUTPUTS_PATH
with Pipeline(name="unit-test-pipeline") as pipeline:
dummy_generator = DummyGeneratorStep(name="dummy_generator_step")
dummy_step_1 = DummyStep1(name="dummy_step_1")
@@ -66,7 +71,9 @@ def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None:
dummy_generator.connect(dummy_step_1)
dummy_step_1.connect(dummy_step_2)
- write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps)
+ write_buffer = _WriteBuffer(
+ path=steps_outputs, leaf_steps=pipeline.dag.leaf_steps
+ )
# Add one batch with 5 rows, shouldn't write anything 5 < 50
batch = batch_gen(dummy_step_2.name) # type: ignore
@@ -77,14 +84,14 @@ def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None:
batch = batch_gen(dummy_step_2.name) # type: ignore
write_buffer.add_batch(batch)
- assert Path(folder, "dummy_step_2", "00001.parquet").exists()
+ assert Path(steps_outputs, "dummy_step_2", "00001.parquet").exists()
# Add 50 more rows, we should have a new file
for _ in range(10):
batch = batch_gen(dummy_step_2.name) # type: ignore
write_buffer.add_batch(batch)
- assert Path(folder, "dummy_step_2", "00002.parquet").exists()
+ assert Path(steps_outputs, "dummy_step_2", "00002.parquet").exists()
# Add more rows and close the write buffer, we should have a new file
for _ in range(5):
@@ -93,9 +100,9 @@ def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None:
write_buffer.close()
- assert Path(folder, "dummy_step_2", "00003.parquet").exists()
+ assert Path(steps_outputs, "dummy_step_2", "00003.parquet").exists()
- ds = create_distiset(write_buffer._path)
+ ds = create_distiset(folder)
assert isinstance(ds, Distiset)
assert len(ds.keys()) == 1
assert len(ds["default"]["train"]) == 125
@@ -103,6 +110,7 @@ def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None:
def test_write_buffer_multiple_leaf_steps_and_create_dataset(self) -> None:
with tempfile.TemporaryDirectory() as tmpdirname:
folder = Path(tmpdirname) / "data"
+ steps_outputs = folder / STEPS_OUTPUTS_PATH
with Pipeline(name="unit-test-pipeline") as pipeline:
dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1")
dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2")
@@ -115,19 +123,21 @@ def test_write_buffer_multiple_leaf_steps_and_create_dataset(self) -> None:
dummy_step_1.connect(dummy_step_2)
dummy_step_1.connect(dummy_step_3)
- write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps)
+ write_buffer = _WriteBuffer(
+ path=steps_outputs, leaf_steps=pipeline.dag.leaf_steps
+ )
for _ in range(10):
batch = batch_gen(dummy_step_2.name) # type: ignore
write_buffer.add_batch(batch)
- assert Path(folder, "dummy_step_2", "00001.parquet").exists()
+ assert Path(steps_outputs, "dummy_step_2", "00001.parquet").exists()
for _ in range(10):
batch = batch_gen(dummy_step_3.name) # type: ignore
write_buffer.add_batch(batch)
- assert Path(folder, "dummy_step_3", "00001.parquet").exists()
+ assert Path(steps_outputs, "dummy_step_3", "00001.parquet").exists()
for _ in range(5):
batch = batch_gen(dummy_step_2.name) # type: ignore
@@ -139,10 +149,10 @@ def test_write_buffer_multiple_leaf_steps_and_create_dataset(self) -> None:
write_buffer.close()
- assert Path(folder, "dummy_step_2", "00002.parquet").exists()
- assert Path(folder, "dummy_step_3", "00002.parquet").exists()
+ assert Path(steps_outputs, "dummy_step_2", "00002.parquet").exists()
+ assert Path(steps_outputs, "dummy_step_3", "00002.parquet").exists()
- ds = create_distiset(write_buffer._path)
+ ds = create_distiset(folder)
assert isinstance(ds, Distiset)
assert len(ds.keys()) == 2
assert len(ds["dummy_step_2"]["train"]) == 75
diff --git a/tests/unit/pipeline/utils.py b/tests/unit/pipeline/utils.py
index 937a9c68bd..cb223755aa 100644
--- a/tests/unit/pipeline/utils.py
+++ b/tests/unit/pipeline/utils.py
@@ -42,6 +42,8 @@ def outputs(self) -> List[str]:
class DummyStep1(Step):
+ attr1: int = 5
+
@property
def inputs(self) -> List[str]:
return ["instruction"]
diff --git a/tests/unit/steps/argilla/test_base.py b/tests/unit/steps/argilla/test_base.py
index 78d70dd162..c0a452e72b 100644
--- a/tests/unit/steps/argilla/test_base.py
+++ b/tests/unit/steps/argilla/test_base.py
@@ -188,6 +188,7 @@ def test_serialization(self) -> None:
"description": "The API key to authenticate the requests to the Argilla API.",
},
],
+ "use_cache": True,
"type_info": {
"module": "tests.unit.steps.argilla.test_base",
"name": "CustomArgilla",
diff --git a/tests/unit/steps/argilla/test_preference.py b/tests/unit/steps/argilla/test_preference.py
index 398ee58d34..ab63ee5419 100644
--- a/tests/unit/steps/argilla/test_preference.py
+++ b/tests/unit/steps/argilla/test_preference.py
@@ -26,7 +26,7 @@
@pytest.fixture
def mock_dataset() -> rg.Dataset: # type: ignore
rg.Argilla._validate_connection = mock.MagicMock() # type: ignore
- client = rg.Argilla(api_url="", api_key="")
+ client = rg.Argilla(api_url="https://example.com", api_key="")
return rg.Dataset(
name="dataset",
settings=rg.Settings(
@@ -180,6 +180,7 @@ def test_serialization(self) -> None:
"description": "The API key to authenticate the requests to the Argilla API.",
},
],
+ "use_cache": True,
"type_info": {
"module": "distilabel.steps.argilla.preference",
"name": "PreferenceToArgilla",
diff --git a/tests/unit/steps/argilla/test_text_generation.py b/tests/unit/steps/argilla/test_text_generation.py
index 8071649028..356bf5a2e7 100644
--- a/tests/unit/steps/argilla/test_text_generation.py
+++ b/tests/unit/steps/argilla/test_text_generation.py
@@ -26,7 +26,7 @@
@pytest.fixture
def mock_dataset() -> rg.Dataset:
rg.Argilla._validate_connection = mock.MagicMock() # type: ignore
- client = rg.Argilla(api_url="", api_key="")
+ client = rg.Argilla(api_url="https://example.com", api_key="")
return rg.Dataset(
name="dataset",
settings=rg.Settings(
@@ -155,6 +155,7 @@ def test_serialization(self) -> None:
"description": "The API key to authenticate the requests to the Argilla API.",
},
],
+ "use_cache": True,
"type_info": {
"module": "distilabel.steps.argilla.text_generation",
"name": "TextGenerationToArgilla",
diff --git a/tests/unit/steps/clustering/__init__.py b/tests/unit/steps/clustering/__init__.py
new file mode 100644
index 0000000000..20ce00bda7
--- /dev/null
+++ b/tests/unit/steps/clustering/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/tests/unit/steps/clustering/test_dbscan.py b/tests/unit/steps/clustering/test_dbscan.py
new file mode 100644
index 0000000000..d4f62a3fae
--- /dev/null
+++ b/tests/unit/steps/clustering/test_dbscan.py
@@ -0,0 +1,39 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from distilabel.steps.clustering.dbscan import DBSCAN
+
+
+class TestDBSCAN:
+ def test_process(self) -> None:
+ step = DBSCAN(n_jobs=1, eps=0.5, min_samples=5)
+ step.load()
+
+ results = next(
+ step.process(
+ inputs=[
+ {"projection": [0.1, -0.4]},
+ {"projection": [-0.3, 0.9]},
+ {"projection": [0.6, 0.2]},
+ {"projection": [-0.2, -0.6]},
+ {"projection": [0.9, 0.1]},
+ {"projection": [0.4, -0.7]},
+ {"projection": [-0.5, 0.3]},
+ {"projection": [0.7, 0.5]},
+ {"projection": [-0.1, -0.9]},
+ ]
+ )
+ )
+ assert all(result["cluster_label"] == -1 for result in results)
diff --git a/tests/unit/steps/clustering/test_text_clustering.py b/tests/unit/steps/clustering/test_text_clustering.py
new file mode 100644
index 0000000000..4b2da96d40
--- /dev/null
+++ b/tests/unit/steps/clustering/test_text_clustering.py
@@ -0,0 +1,75 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import TYPE_CHECKING
+
+import pytest
+
+from distilabel.steps.clustering.text_clustering import TextClustering
+from tests.unit.conftest import DummyAsyncLLM
+
+if TYPE_CHECKING:
+ from distilabel.llms.typing import GenerateOutput
+ from distilabel.steps.tasks.typing import FormattedInput
+
+
+class ClusteringLLM(DummyAsyncLLM):
+ n: int = 1
+
+ async def agenerate( # type: ignore
+ self, input: "FormattedInput", num_generations: int = 1
+ ) -> "GenerateOutput":
+ if self.n == 1:
+ return [json.dumps({"labels": "label"}) for _ in range(num_generations)]
+ return [
+ json.dumps({"labels": ["label" for _ in range(self.n)]})
+ for _ in range(self.n)
+ ]
+
+
+class TestTextClustering:
+ @pytest.mark.parametrize("n", [1, 3])
+ def test_process(self, n: int) -> None:
+ step = TextClustering(
+ llm=ClusteringLLM(n=n),
+ n=n,
+ samples_per_cluster=2,
+ savefig=False,
+ )
+ step.load()
+
+ results = next(
+ step.process(
+ inputs=[
+ {"projection": [0.1, -0.4], "cluster_label": -1, "text": "hello"},
+ {"projection": [-0.3, 0.9], "cluster_label": -1, "text": "hello"},
+ {"projection": [0.6, 0.2], "cluster_label": 0, "text": "hello"},
+ {"projection": [-0.2, -0.6], "cluster_label": 0, "text": "hello"},
+ {"projection": [0.9, 0.1], "cluster_label": 0, "text": "hello"},
+ {"projection": [0.4, -0.7], "cluster_label": 1, "text": "hello"},
+ {"projection": [-0.5, 0.3], "cluster_label": 1, "text": "hello"},
+ {"projection": [0.7, 0.5], "cluster_label": 2, "text": "hello"},
+ {"projection": [-0.1, -0.9], "cluster_label": 2, "text": "hello"},
+ ]
+ )
+ )
+ for r in results:
+ if r["cluster_label"] == -1:
+ assert r["summary_label"] == json.dumps("Unclassified")
+ else:
+ if n == 1:
+ assert r["summary_label"] == json.dumps("label")
+ else:
+ assert r["summary_label"] == json.dumps(["label"] * n)
diff --git a/tests/unit/steps/clustering/test_umap.py b/tests/unit/steps/clustering/test_umap.py
new file mode 100644
index 0000000000..3ab252fd24
--- /dev/null
+++ b/tests/unit/steps/clustering/test_umap.py
@@ -0,0 +1,42 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from distilabel.steps.clustering.umap import UMAP
+
+
+class TestUMAP:
+ def test_process(self) -> None:
+ n_components = 2
+ step = UMAP(n_jobs=1, n_components=n_components)
+ step.load()
+
+ results = next(
+ step.process(
+ inputs=[
+ {"embedding": [0.1, -0.4, 0.7, 0.2]},
+ {"embedding": [-0.3, 0.9, 0.1, -0.5]},
+ {"embedding": [0.6, 0.2, -0.1, 0.8]},
+ {"embedding": [-0.2, -0.6, 0.4, 0.3]},
+ {"embedding": [0.9, 0.1, -0.3, -0.2]},
+ {"embedding": [0.4, -0.7, 0.6, 0.1]},
+ {"embedding": [-0.5, 0.3, -0.2, 0.9]},
+ {"embedding": [0.7, 0.5, -0.4, -0.1]},
+ {"embedding": [-0.1, -0.9, 0.8, 0.6]},
+ ]
+ )
+ )
+ assert all(isinstance(result["projection"], np.ndarray) for result in results)
+ assert all(len(result["projection"]) == n_components for result in results)
diff --git a/tests/unit/steps/columns/__init__.py b/tests/unit/steps/columns/__init__.py
new file mode 100644
index 0000000000..20ce00bda7
--- /dev/null
+++ b/tests/unit/steps/columns/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/tests/unit/steps/columns/test_combine.py b/tests/unit/steps/columns/test_combine.py
new file mode 100644
index 0000000000..817d89e90b
--- /dev/null
+++ b/tests/unit/steps/columns/test_combine.py
@@ -0,0 +1,54 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.constants import DISTILABEL_METADATA_KEY
+from distilabel.steps.columns.combine import CombineOutputs
+
+
+class TestCombineOutputs:
+ def test_process(self) -> None:
+ combine = CombineOutputs()
+
+ output = next(
+ combine.process(
+ [
+ {
+ "a": 1,
+ "b": 2,
+ DISTILABEL_METADATA_KEY: {"model": "model-1", "a": 1},
+ }
+ ],
+ [
+ {
+ "c": 3,
+ "d": 4,
+ DISTILABEL_METADATA_KEY: {"model": "model-2", "b": 1},
+ }
+ ],
+ )
+ )
+
+ assert output == [
+ {
+ "a": 1,
+ "b": 2,
+ "c": 3,
+ "d": 4,
+ DISTILABEL_METADATA_KEY: {
+ "model": ["model-1", "model-2"],
+ "a": 1,
+ "b": 1,
+ },
+ }
+ ]
diff --git a/tests/unit/steps/columns/test_group.py b/tests/unit/steps/columns/test_group.py
index 258029d7b9..57f9f114de 100644
--- a/tests/unit/steps/columns/test_group.py
+++ b/tests/unit/steps/columns/test_group.py
@@ -15,6 +15,7 @@
import pytest
+from distilabel.constants import DISTILABEL_METADATA_KEY
from distilabel.pipeline.local import Pipeline
from distilabel.steps.columns.group import CombineColumns, GroupColumns
@@ -44,8 +45,19 @@ def test_process(self) -> None:
columns=["a", "b"],
pipeline=Pipeline(name="unit-test-pipeline"),
)
- output = next(group.process([{"a": 1, "b": 2}], [{"a": 3, "b": 4}]))
- assert output == [{"grouped_a": [1, 3], "grouped_b": [2, 4]}]
+ output = next(
+ group.process(
+ [{"a": 1, "b": 2, DISTILABEL_METADATA_KEY: {"model": "model-1"}}],
+ [{"a": 3, "b": 4, DISTILABEL_METADATA_KEY: {"model": "model-2"}}],
+ )
+ )
+ assert output == [
+ {
+ "grouped_a": [1, 3],
+ "grouped_b": [2, 4],
+ DISTILABEL_METADATA_KEY: {"model": ["model-1", "model-2"]},
+ }
+ ]
def test_CombineColumns_deprecation_warning():
diff --git a/tests/unit/steps/filtering/__init__.py b/tests/unit/steps/filtering/__init__.py
new file mode 100644
index 0000000000..20ce00bda7
--- /dev/null
+++ b/tests/unit/steps/filtering/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/tests/unit/steps/filtering/test_embeddings.py b/tests/unit/steps/filtering/test_embeddings.py
new file mode 100644
index 0000000000..354777bd94
--- /dev/null
+++ b/tests/unit/steps/filtering/test_embeddings.py
@@ -0,0 +1,104 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from distilabel.steps.filtering.embedding import EmbeddingDedup
+
+SAMPLE_DATA = [
+ {
+ "persona": "A chemistry student or academic researcher interested in inorganic or physical chemistry, likely at an advanced undergraduate or graduate level, studying acid-base interactions and chemical bonding.",
+ "embedding": [
+ 0.018477669046149742,
+ -0.03748236608841726,
+ 0.001919870620352492,
+ 0.024918478063770535,
+ 0.02348063521315178,
+ 0.0038251285566308375,
+ -0.01723884983037716,
+ 0.02881971942372201,
+ ],
+ "nn_indices": [0, 1],
+ "nn_scores": [
+ 0.9164746999740601,
+ 0.782106876373291,
+ ],
+ },
+ {
+ "persona": "A music teacher or instructor focused on theoretical and practical piano lessons.",
+ "embedding": [
+ -0.0023464179614082125,
+ -0.07325472251663565,
+ -0.06058678419516501,
+ -0.02100326928586996,
+ -0.013462744792362657,
+ 0.027368447064244242,
+ -0.003916070100455717,
+ 0.01243614518480423,
+ ],
+ "nn_indices": [0, 2],
+ "nn_scores": [
+ 0.7552462220191956,
+ 0.7261884808540344,
+ ],
+ },
+ {
+ "persona": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.",
+ "embedding": [
+ -0.01630817942328242,
+ -0.023760151552345232,
+ -0.014249650090627883,
+ -0.005713686451446624,
+ -0.016033059279131567,
+ 0.0071440908501058786,
+ -0.05691099643425161,
+ 0.01597412704817784,
+ ],
+ "nn_indices": [1, 2],
+ "nn_scores": [
+ 0.8107735514640808,
+ 0.7172299027442932,
+ ],
+ },
+ {
+ "persona": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.",
+ "embedding": [
+ -0.01630817942328242,
+ -0.023760151552345232,
+ -0.014249650090627883,
+ -0.005713686451446624,
+ -0.016033059279131567,
+ 0.0071440908501058786,
+ -0.05691099643425161,
+ 0.01597412704817784,
+ ],
+ "nn_indices": [],
+ "nn_scores": [],
+ },
+]
+
+
+class TestEmbeddingDedup:
+ @pytest.mark.parametrize(
+ "threshold, keep_row_after_embedding_filtering",
+ [(0.1, 1), (0.9, 3), (0.99999, 4)],
+ )
+ def test_process(
+ self, threshold: float, keep_row_after_embedding_filtering: int
+ ) -> None:
+ step = EmbeddingDedup(threshold=threshold)
+ step.load()
+ result = next(step.process(SAMPLE_DATA))
+ duplicated = [r["keep_row_after_embedding_filtering"] for r in result]
+ assert sum(duplicated) == keep_row_after_embedding_filtering
diff --git a/tests/unit/steps/filtering/test_minhash.py b/tests/unit/steps/filtering/test_minhash.py
new file mode 100644
index 0000000000..48b765c1a9
--- /dev/null
+++ b/tests/unit/steps/filtering/test_minhash.py
@@ -0,0 +1,66 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import nltk
+import pytest
+
+from distilabel.steps.filtering.minhash import (
+ MinHashDedup,
+ tokenize_on_ngrams,
+ tokenized_on_words,
+)
+
+nltk.download("punkt_tab")
+
+texts: List[str] = [
+ "This is a test document.",
+ "This document is a test.",
+ "Test document for duplication.",
+ "Document for duplication test.",
+ "This is another unique document.",
+]
+
+
+def test_tokenize_on_words() -> None:
+ tokenized = tokenized_on_words(texts)
+ assert len(tokenized) == len(texts)
+ assert tokenized[0] == {b".", b"This", b"a", b"document", b"is", b"test"}
+
+
+@pytest.mark.parametrize("n", [1, 3])
+def test_tokenize_on_ngrams(n: int) -> None:
+ tokenized = tokenize_on_ngrams(texts, n=n)
+ assert len(tokenized) == len(texts)
+ assert all(len(t) == n for t in tokenized[0])
+
+
+class TestMinHashDedup:
+ @pytest.mark.parametrize(
+ "threshold, keep_row_after_minhash_filtering, storage",
+ [(0.1, 1, "dict"), (0.9, 4, "dict"), (0.9, 4, "disk")],
+ )
+ def test_process(
+ self, threshold: float, keep_row_after_minhash_filtering: int, storage: str
+ ) -> None:
+ msh = MinHashDedup(
+ threshold=threshold,
+ storage=storage,
+ )
+ msh.load()
+ result = next(msh.process([{"text": t} for t in texts]))
+ duplicated = [r["keep_row_after_minhash_filtering"] for r in result]
+ assert sum(duplicated) == keep_row_after_minhash_filtering
+ msh.unload()
diff --git a/tests/unit/steps/generators/test_data.py b/tests/unit/steps/generators/test_data.py
index 9684d5abd0..3767451991 100644
--- a/tests/unit/steps/generators/test_data.py
+++ b/tests/unit/steps/generators/test_data.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-from pydantic import ValidationError
from distilabel.pipeline.local import Pipeline
from distilabel.steps.generators.data import LoadDataFromDicts
@@ -30,11 +29,6 @@ def test_init(self) -> None:
assert task.data == data
assert task.batch_size == 10
- def test_with_errors(self) -> None:
- pipeline = Pipeline(name="unit-test-pipeline")
- with pytest.raises(ValidationError):
- LoadDataFromDicts(name="task", pipeline=pipeline)
-
def test_process(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
data: list[dict[str, str]] = self.data
diff --git a/tests/unit/steps/generators/test_data_sampler.py b/tests/unit/steps/generators/test_data_sampler.py
new file mode 100644
index 0000000000..32882e0379
--- /dev/null
+++ b/tests/unit/steps/generators/test_data_sampler.py
@@ -0,0 +1,45 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import pytest
+
+from distilabel.steps.generators.data_sampler import DataSampler
+
+
+@pytest.mark.parametrize(
+ "samples, size, batch_size, expected",
+ [
+ (10, 2, 4, [4, 4, 2]),
+ (7, 5, 6, [6, 1]),
+ (20, 5, 20, [20]),
+ (20, 50, 8, [8, 8, 4]),
+ ],
+)
+def test_generator_and_sampler(
+ samples: int, size: int, batch_size: int, expected: List[int]
+):
+ sampler = DataSampler(
+ data=[{"sample": f"sample {i}"} for i in range(30)],
+ size=size,
+ samples=samples,
+ batch_size=batch_size,
+ )
+ sampler.load()
+ results = [item[0] for item in sampler.process()]
+ assert len(results) == len(expected)
+ assert len(results[0]) == batch_size
+ for i, result in enumerate(results):
+ assert len(result) == expected[i]
diff --git a/tests/unit/steps/generators/test_huggingface.py b/tests/unit/steps/generators/test_huggingface.py
index d115f8cf8f..fbae2a86c1 100644
--- a/tests/unit/steps/generators/test_huggingface.py
+++ b/tests/unit/steps/generators/test_huggingface.py
@@ -27,6 +27,7 @@
LoadDataFromFileSystem,
LoadDataFromHub,
)
+from tests.unit.pipeline.utils import DummyStep1
DISTILABEL_RUN_SLOW_TESTS = os.getenv("DISTILABEL_RUN_SLOW_TESTS", False)
@@ -103,7 +104,7 @@ def test_read_from_jsonl_with_folder(self, filetype: Union[str, None]) -> None:
loader = LoadDataFromFileSystem(
filetype=filetype,
- data_files=tmpdir,
+ data_files=str(Path(tmpdir) / "*.jsonl"),
)
loader.load()
generator_step_output = next(loader.process())
@@ -126,7 +127,7 @@ def test_read_from_jsonl_with_nested_folder(
loader = LoadDataFromFileSystem(
filetype=filetype,
- data_files=tmpdir,
+ data_files=str(Path(tmpdir) / "**/*.jsonl"),
)
loader.load()
generator_step_output = next(loader.process())
@@ -134,18 +135,23 @@ def test_read_from_jsonl_with_nested_folder(
assert isinstance(generator_step_output[1], bool)
assert len(generator_step_output[0]) == 22
- @pytest.mark.parametrize("load", [True, False])
- def test_outputs(self, load: bool) -> None:
+ def test_outputs(self) -> None:
loader = LoadDataFromFileSystem(
filetype="json",
data_files=str(Path(__file__).parent / "sample_functions.jsonl"),
)
- if load:
- loader.load()
- assert loader.outputs == ["type", "function"]
- else:
- with pytest.raises(ValueError):
- loader.outputs # noqa: B018
+ loader.load()
+ assert loader.outputs == ["type", "function"]
+
+ def test_loading_in_pipeline(self):
+ with Pipeline():
+ loader = LoadDataFromFileSystem(
+ filetype="json",
+ data_files=str(Path(__file__).parent / "sample_functions.jsonl"),
+ )
+ dummy = DummyStep1(input_mappings={"instruction": "function"})
+ loader >> dummy
+ assert loader.outputs == ["type", "function"]
class TestLoadDataFromDisk:
@@ -162,6 +168,32 @@ def test_load_dataset_from_disk(self) -> None:
assert isinstance(generator_step_output[1], bool)
assert len(generator_step_output[0]) == 3
+ @pytest.mark.parametrize("config_name", ["default", "missnamed_config"])
+ def test_load_distiset_from_disk_default(self, config_name: str) -> None:
+ distiset = Distiset(
+ {
+ "default": Dataset.from_dict({"a": [1, 2, 3]}),
+ }
+ )
+ with tempfile.TemporaryDirectory() as tmpdir:
+ dataset_path = str(Path(tmpdir) / "dataset_path")
+ distiset.save_to_disk(dataset_path)
+
+ loader = LoadDataFromDisk(
+ dataset_path=dataset_path,
+ is_distiset=True,
+ config=config_name,
+ )
+ if config_name != "default":
+ with pytest.raises(ValueError):
+ loader.load()
+ else:
+ loader.load()
+ generator_step_output = next(loader.process())
+ assert isinstance(generator_step_output, tuple)
+ assert isinstance(generator_step_output[1], bool)
+ assert len(generator_step_output[0]) == 3
+
def test_load_distiset_from_disk(self) -> None:
distiset = Distiset(
{
diff --git a/tests/unit/steps/generators/test_utils.py b/tests/unit/steps/generators/test_utils.py
index 67323bb9c4..f25f260f79 100644
--- a/tests/unit/steps/generators/test_utils.py
+++ b/tests/unit/steps/generators/test_utils.py
@@ -18,7 +18,8 @@
import pytest
from datasets import Dataset
-from distilabel.steps import make_generator_step
+from distilabel.pipeline.local import Pipeline
+from distilabel.steps.generators.utils import make_generator_step
data = [{"instruction": "Tell me a joke."}] * 10
@@ -26,7 +27,7 @@
@pytest.mark.parametrize("dataset", (data, Dataset.from_list(data), pd.DataFrame(data)))
def test_make_generator_step(
dataset: Union[Dataset, pd.DataFrame, List[Dict[str, str]]],
-):
+) -> None:
batch_size = 5
load_dataset = make_generator_step(
dataset, batch_size=batch_size, output_mappings={"instruction": "other"}
@@ -40,3 +41,9 @@ def test_make_generator_step(
assert isinstance(load_dataset.data, list)
assert load_dataset.output_mappings == {"instruction": "other"}
+
+
+def test_make_generator_step_with_pipeline() -> None:
+ pipeline = Pipeline()
+ load_dataset = make_generator_step(data, pipeline=pipeline)
+ assert load_dataset.pipeline == pipeline
diff --git a/tests/unit/steps/tasks/apigen/__init__.py b/tests/unit/steps/tasks/apigen/__init__.py
new file mode 100644
index 0000000000..20ce00bda7
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/tests/unit/steps/tasks/apigen/_sample_lib/final_velocity.py b/tests/unit/steps/tasks/apigen/_sample_lib/final_velocity.py
new file mode 100644
index 0000000000..abcc66214c
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/_sample_lib/final_velocity.py
@@ -0,0 +1,27 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def final_velocity(initial_velocity: float, acceleration: float, time: float) -> int:
+ """Calculates the final velocity of an object given its initial velocity, acceleration, and time.
+
+ Args:
+ initial_velocity: The initial velocity of the object.
+ acceleration: The acceleration of the object.
+ time: The time elapsed.
+
+ Returns:
+ The final velocity
+ """
+ return initial_velocity + acceleration * time
diff --git a/tests/unit/steps/tasks/apigen/_sample_lib/get_value.py b/tests/unit/steps/tasks/apigen/_sample_lib/get_value.py
new file mode 100644
index 0000000000..db3bd1bccf
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/_sample_lib/get_value.py
@@ -0,0 +1,33 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+
+
+def get_value(matrix: List[List[int]], indices: Tuple[int, int]) -> Optional[int]:
+ """Gets the value at the specified index in the matrix.
+
+ Args:
+ matrix: A list of lists representing the matrix.
+ indices: A tuple containing the row and column indices.
+ """
+ row_index, col_index = indices
+ if (
+ row_index < 0
+ or row_index >= len(matrix)
+ or col_index < 0
+ or col_index >= len(matrix[row_index])
+ ):
+ return None
+ return matrix[row_index][col_index]
diff --git a/tests/unit/steps/tasks/apigen/_sample_module.py b/tests/unit/steps/tasks/apigen/_sample_module.py
new file mode 100644
index 0000000000..6e9e085023
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/_sample_module.py
@@ -0,0 +1,47 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+
+
+def final_velocity(initial_velocity: float, acceleration: float, time: float) -> int:
+ """Calculates the final velocity of an object given its initial velocity, acceleration, and time.
+
+ Args:
+ initial_velocity: The initial velocity of the object.
+ acceleration: The acceleration of the object.
+ time: The time elapsed.
+
+ Returns:
+ The final velocity
+ """
+ return initial_velocity + acceleration * time
+
+
+def get_value(matrix: List[List[int]], indices: Tuple[int, int]) -> Optional[int]:
+ """Gets the value at the specified index in the matrix.
+
+ Args:
+ matrix: A list of lists representing the matrix.
+ indices: A tuple containing the row and column indices.
+ """
+ row_index, col_index = indices
+ if (
+ row_index < 0
+ or row_index >= len(matrix)
+ or col_index < 0
+ or col_index >= len(matrix[row_index])
+ ):
+ return None
+ return matrix[row_index][col_index]
diff --git a/tests/unit/steps/tasks/apigen/test_execution_checker.py b/tests/unit/steps/tasks/apigen/test_execution_checker.py
new file mode 100644
index 0000000000..d70e422715
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/test_execution_checker.py
@@ -0,0 +1,140 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from pathlib import Path
+from typing import Any, Dict
+
+import pytest
+
+from distilabel.steps.tasks.apigen.execution_checker import APIGenExecutionChecker
+
+SAMPLE_LIB = Path(__file__).parent / "_sample_module.py"
+SAMPLE_LIB_FOLDER = Path(__file__).parent / "_sample_lib"
+
+
+class TestAPIGenExecutionChecker:
+ @pytest.mark.parametrize("lib", (SAMPLE_LIB, SAMPLE_LIB_FOLDER))
+ @pytest.mark.parametrize(
+ "answers, expected",
+ [
+ (
+ {
+ "query": "Whats the velocity of X?",
+ "answers": json.dumps(
+ [
+ {
+ "arguments": {
+ "initial_velocity": 0.2,
+ "acceleration": "0.1",
+ "time": 5,
+ },
+ "name": "final_velocity",
+ }
+ ]
+ ),
+ },
+ [
+ {
+ "query": "Whats the velocity of X?",
+ "answers": json.dumps(
+ [
+ {
+ "arguments": {
+ "initial_velocity": 0.2,
+ "acceleration": "0.1",
+ "time": 5,
+ },
+ "name": "final_velocity",
+ }
+ ]
+ ),
+ "keep_row_after_execution_check": True,
+ "execution_result": ["0.7"],
+ }
+ ],
+ ),
+ (
+ {
+ "query": "Other query",
+ "answers": json.dumps(
+ [
+ {
+ "arguments": {
+ "initial_velocity": 0.2,
+ "acceleration": 0.1,
+ "time": 0.5,
+ },
+ "name": "unknown_function",
+ }
+ ]
+ ),
+ },
+ [
+ {
+ "query": "Other query",
+ "answers": json.dumps(
+ [
+ {
+ "arguments": {
+ "initial_velocity": 0.2,
+ "acceleration": 0.1,
+ "time": 0.5,
+ },
+ "name": "unknown_function",
+ }
+ ]
+ ),
+ "keep_row_after_execution_check": False,
+ "execution_result": ["Function 'unknown_function' not found."],
+ }
+ ],
+ ),
+ (
+ {
+ "query": "Other query",
+ "answers": '[{"arguments": {"matrix": "[[1, 2, 3], [4, 5, 6], [7, 8, 9]]", "indices": "[1, 2]"}, "name": "get_value"}]',
+ },
+ [
+ {
+ "query": "Other query",
+ "answers": '[{"arguments": {"matrix": "[[1, 2, 3], [4, 5, 6], [7, 8, 9]]", "indices": "[1, 2]"}, "name": "get_value"}]',
+ "keep_row_after_execution_check": True,
+ "execution_result": ["6"],
+ }
+ ],
+ ),
+ (
+ {
+ "query": "Other query",
+ "answers": None,
+ },
+ [
+ {
+ "query": "Other query",
+ "answers": None,
+ "keep_row_after_execution_check": False,
+ "execution_result": ["No answers were provided."],
+ }
+ ],
+ ),
+ ],
+ )
+ def test_process(
+ self, lib: str, answers: Dict[str, str], expected: Dict[str, Any]
+ ) -> None:
+ task = APIGenExecutionChecker(libpath=str(lib))
+ task.load()
+ result = next(task.process([answers]))
+ assert result == expected
diff --git a/tests/unit/steps/tasks/apigen/test_generator.py b/tests/unit/steps/tasks/apigen/test_generator.py
new file mode 100644
index 0000000000..a290666a60
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/test_generator.py
@@ -0,0 +1,172 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from typing import TYPE_CHECKING, List, Union
+
+import pytest
+
+from distilabel.steps.tasks.apigen.generator import APIGenGenerator
+from tests.unit.conftest import DummyLLM
+
+if TYPE_CHECKING:
+ from distilabel.llms.typing import GenerateOutput
+ from distilabel.steps.tasks.typing import FormattedInput
+
+import json
+
+
+class DummyAPIGenLLM(DummyLLM):
+ use_structured_output: bool = False
+ number: int = 1
+
+ def generate(
+ self, inputs: List["FormattedInput"], num_generations: int = 1
+ ) -> "GenerateOutput":
+ query_answers = [
+ {
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ "answers": [
+ {
+ "name": "get_breed_information",
+ "arguments": {"breed": "Maine Coon"},
+ }
+ ]
+ * self.number,
+ }
+ ]
+ if self.use_structured_output:
+ query_answers = {"pairs": query_answers}
+ return [
+ [json.dumps(query_answers) for _ in range(num_generations)]
+ for _ in range(len(inputs))
+ ]
+
+
+# Example of 3 rows from Salesforce/xlam-function-calling-60k
+SAMPLE_DATA = [
+ {
+ "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ "id": 3493,
+ "tools": '[{"name": "get_breed_information", "description": "Fetch information about a specific cat breed from the Cat Breeds API.", "parameters": {"breed": {"description": "The name of the cat breed to fetch information for.", "type": "str", "default": "aegean"}}}, {"name": "country_region_cities", "description": "Fetches a list of cities within a specified region of a given country from the GeoDB API.", "parameters": {"countryid": {"description": "An ISO-3166 country code or WikiData ID.", "type": "str", "default": "US"}, "regioncode": {"description": "An ISO-3166 or FIPS region code.", "type": "str", "default": "CA"}, "limit": {"description": "The maximum number of results to retrieve. Defaults to None.", "type": "int, optional", "default": ""}, "hateoasmode": {"description": "Include HATEOAS-style links in results. Defaults to None.", "type": "bool, optional", "default": ""}, "asciimode": {"description": "Display results using ASCII characters. Defaults to None.", "type": "bool, optional", "default": ""}, "nameprefixdefaultlangresults": {"description": "Match on names in the default language if a non-default language is requested when prefix-matching. Defaults to None.", "type": "bool, optional", "default": ""}, "timezoneids": {"description": "Only include cities in these time zones. Comma-separated values. Defaults to None.", "type": "str, optional", "default": ""}, "nameprefix": {"description": "Only include cities whose names start with this prefix. If languagecode is set, the prefix will be matched on the name as it appears in that language. Defaults to None.", "type": "str, optional", "default": ""}, "types": {"description": "Only include cities of these types (comma-separated): CITY, ADM2. Defaults to None.", "type": "str, optional", "default": ""}, "minpopulation": {"description": "Only include cities with at least this population. Defaults to None.", "type": "int, optional", "default": ""}, "languagecode": {"description": "Display results in this language. Defaults to None.", "type": "str, optional", "default": ""}, "offset": {"description": "The zero-based offset into the results. Defaults to None.", "type": "int, optional", "default": ""}, "maxpopulation": {"description": "Only include cities with no more than this population. Defaults to None.", "type": "int, optional", "default": ""}, "includedeleted": {"description": "Whether to include any cities marked deleted. Options are: ALL, SINCE_YESTERDAY, SINCE_LAST_WEEK, NONE. Defaults to None.", "type": "str, optional", "default": ""}, "sort": {"description": "How to sort the results. Format: \\u00b1SORT_FIELD,\\u00b1SORT_FIELD where SORT_FIELD = elevation, name, population. Defaults to None.", "type": "str, optional", "default": ""}}}, {"name": "company_details", "description": "Fetch details of a company from Indeed\'s API.", "parameters": {"company_id": {"description": "The unique identifier of the company to fetch details for.", "type": "str", "default": "Microsoft"}, "locality": {"description": "The locality or country code for Indeed\'s subdomain. Default is \'us\' if not provided.", "type": "str, optional", "default": ""}}}]',
+ },
+ {
+ "answers": '[{"name": "mailcheck", "arguments": {"domain": "protonmail.com"}}, {"name": "mailcheck", "arguments": {"domain": "mail.com"}}, {"name": "get_products_in_category", "arguments": {"skip": 20, "limit": 25, "category": "furniture"}}]',
+ "query": "Check if the email domains 'protonmail.com' and 'mail.com' are valid and not temporary. Get the products from category 'furniture' in my store, skipping the first 20 items and limiting to 25 items.",
+ "id": 57546,
+ "tools": '[{"name": "mailcheck", "description": "Checks if an email domain is valid or a disposable/temporary address.", "parameters": {"domain": {"description": "The email or domain to check for validity. It is recommended to enter just the domain for user privacy.", "type": "str", "default": "mailinator.com"}}}, {"name": "get_products_in_category", "description": "Fetches a list of products from a specified category in a store with pagination.", "parameters": {"skip": {"description": "The number of items to skip before starting to collect the result set.", "type": "int", "default": ""}, "limit": {"description": "The number of items to return in the result set.", "type": "int", "default": ""}, "category": {"description": "The category from which to fetch products.", "type": "str", "default": ""}}}, {"name": "product_by_id", "description": "Fetches detailed information about a specific product from the AliExpress API using the provided product ID.", "parameters": {"product_id": {"description": "The unique identifier for the product on AliExpress.", "type": "int", "default": "32841070485"}}}]',
+ },
+ {
+ "answers": '[{"name": "navigations_get_node_content", "arguments": {"is_id": 8899, "cat_id": 8899, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 7766, "cat_id": 7766, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 5544, "cat_id": 5544, "language": "fr"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 3322, "cat_id": 3322, "language": "fr"}}]',
+ "query": "What are the node contents for category IDs 8899 and 7766 in English and for category IDs 5544 and 3322 in French?",
+ "id": 8815,
+ "tools": '[{"name": "navigations_get_node_content", "description": "Fetches the content of a node in a navigation hierarchy.", "parameters": {"is_id": {"description": "The \'id\' field value returned from the /navigations/get-root endpoint.", "type": "int", "default": "26066300130"}, "cat_id": {"description": "The \'cat_id\' field value returned from the /navigations/get-tabs endpoint.", "type": "int", "default": "2026"}, "language": {"description": "The 2-letter language code (default is \'en\').", "type": "str, optional", "default": "en"}, "currency": {"description": "The 3-letter currency code (default is \'USD\').", "type": "str, optional", "default": "USD"}, "country": {"description": "The 2-letter country code (default is \'US\').", "type": "str, optional", "default": "US"}}}, {"name": "products_get_reviews", "description": "Fetches brief reviews of a product from the Shein API.", "parameters": {"goods_spu": {"description": "The value of \'productRelationID\' returned in the /products/list or /products/search endpoints. Defaults to \'m22022854841\'.", "type": "str, optional", "default": "m22022854841"}, "cat_id": {"description": "The value of \'cat_id\' returned in the /products/list or /products/search endpoints. Defaults to \'1727\'.", "type": "str, optional", "default": "1727"}, "sku": {"description": "The value of \'goods_sn\' returned in the /products/list or /products/search endpoints. Defaults to \'rm2202285484176751\'.", "type": "str, optional", "default": "rm2202285484176751"}, "currency": {"description": "The 3-letter currency code. Defaults to \'USD\'.", "type": "str, optional", "default": "USD"}, "goods_id": {"description": "The value of \'goods_id\' field returned in the /products/list or /products/search endpoints. Defaults to \'10196865\'.", "type": "str, optional", "default": "10196865"}, "language": {"description": "The 2-letter language code. Defaults to \'en\'.", "type": "str, optional", "default": "en"}, "country": {"description": "The 2-letter country code. Defaults to \'US\'.", "type": "str, optional", "default": "US"}}}]',
+ },
+]
+
+
+class TestApiGenGenerator:
+ @pytest.mark.parametrize("number", [1, 2, [3]])
+ @pytest.mark.parametrize("use_default_structured_output", [True, False])
+ @pytest.mark.parametrize("use_tools", [True, False])
+ def test_format_input(
+ self,
+ number: Union[int, List[int]],
+ use_default_structured_output: bool,
+ use_tools: bool,
+ ) -> None:
+ random.seed(42)
+ task = APIGenGenerator(
+ llm=DummyLLM(),
+ number=number,
+ use_tools=use_tools,
+ use_default_structured_output=use_default_structured_output,
+ )
+ task.load()
+ formatted = task.format_input(
+ input={
+ "examples": '## Query:\nWhat information can be obtained about the Maine Coon cat breed?\n## Answer:\n[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "func_name": "get_breed_information",
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "tools": '[{"name": "navigations_get_node_content", "description": "Fetches the content of a node in a navigation hierarchy.", "parameters": {"is_id": {"description": "The \'id\' field value returned from the /navigations/get-root endpoint.", "type": "int", "default": "26066300130"}, "cat_id": {"description": "The \'cat_id\' field value returned from the /navigations/get-tabs endpoint.", "type": "int", "default": "2026"}, "language": {"description": "The 2-letter language code (default is \'en\').", "type": "str, optional", "default": "en"}, "currency": {"description": "The 3-letter currency code (default is \'USD\').", "type": "str, optional", "default": "USD"}, "country": {"description": "The 2-letter country code (default is \'US\').", "type": "str, optional", "default": "US"}}}, {"name": "products_get_reviews", "description": "Fetches brief reviews of a product from the Shein API.", "parameters": {"goods_spu": {"description": "The value of \'productRelationID\' returned in the /products/list or /products/search endpoints. Defaults to \'m22022854841\'.", "type": "str, optional", "default": "m22022854841"}, "cat_id": {"description": "The value of \'cat_id\' returned in the /products/list or /products/search endpoints. Defaults to \'1727\'.", "type": "str, optional", "default": "1727"}, "sku": {"description": "The value of \'goods_sn\' returned in the /products/list or /products/search endpoints. Defaults to \'rm2202285484176751\'.", "type": "str, optional", "default": "rm2202285484176751"}, "currency": {"description": "The 3-letter currency code. Defaults to \'USD\'.", "type": "str, optional", "default": "USD"}, "goods_id": {"description": "The value of \'goods_id\' field returned in the /products/list or /products/search endpoints. Defaults to \'10196865\'.", "type": "str, optional", "default": "10196865"}, "language": {"description": "The 2-letter language code. Defaults to \'en\'.", "type": "str, optional", "default": "en"}, "country": {"description": "The 2-letter country code. Defaults to \'US\'.", "type": "str, optional", "default": "US"}}}]',
+ }
+ )
+
+ assert isinstance(formatted, list)
+ # Check only the user prompt, the system one should be fixed
+ formatted_prompt = formatted[1]["content"]
+
+ if isinstance(number, list):
+ # Fix the number for the tests for simplicity
+ number = 3
+ assert f"Now please generate {number} diverse" in formatted_prompt
+
+ assert (
+ "The output MUST strictly adhere to the following JSON format, and NO other text MUST be included:"
+ in formatted_prompt
+ )
+
+ tools_entry = "This is the available tool to guide you (respect the order of the parameters):"
+ if use_tools:
+ assert tools_entry in formatted_prompt
+ else:
+ assert tools_entry not in formatted_prompt
+
+ is_parallel_check = "It can contain multiple parallel queries in natural language for the given functions. They could use either the same function with different arguments or different functions."
+ if number > 1:
+ assert is_parallel_check in formatted_prompt
+ else:
+ assert is_parallel_check not in formatted_prompt
+
+ @pytest.mark.parametrize("number", [1, 2])
+ @pytest.mark.parametrize("use_default_structured_output", [True, False])
+ @pytest.mark.parametrize("use_tools", [True, False])
+ def test_process(
+ self,
+ number: Union[int, List[int]],
+ use_default_structured_output: bool,
+ use_tools: bool,
+ ) -> None:
+ # Is parallel is not relevant in this case, it's only relevant for the format_input
+ # as it will be multiple questions in the prompt
+ random.seed(42)
+ task = APIGenGenerator(
+ llm=DummyAPIGenLLM(
+ use_structured_output=use_default_structured_output, number=number
+ ),
+ number=number,
+ use_tools=use_tools,
+ use_default_structured_output=use_default_structured_output,
+ )
+ task.load()
+ result = next(
+ task.process(
+ [
+ {
+ "examples": '## Query:\nWhat information can be obtained about the Maine Coon cat breed?\n## Answer:\n[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "func_name": "get_breed_information",
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "tools": '[{"name": "navigations_get_node_content", "description": "Fetches the content of a node in a navigation hierarchy.", "parameters": {"is_id": {"description": "The \'id\' field value returned from the /navigations/get-root endpoint.", "type": "int", "default": "26066300130"}, "cat_id": {"description": "The \'cat_id\' field value returned from the /navigations/get-tabs endpoint.", "type": "int", "default": "2026"}, "language": {"description": "The 2-letter language code (default is \'en\').", "type": "str, optional", "default": "en"}, "currency": {"description": "The 3-letter currency code (default is \'USD\').", "type": "str, optional", "default": "USD"}, "country": {"description": "The 2-letter country code (default is \'US\').", "type": "str, optional", "default": "US"}}}, {"name": "products_get_reviews", "description": "Fetches brief reviews of a product from the Shein API.", "parameters": {"goods_spu": {"description": "The value of \'productRelationID\' returned in the /products/list or /products/search endpoints. Defaults to \'m22022854841\'.", "type": "str, optional", "default": "m22022854841"}, "cat_id": {"description": "The value of \'cat_id\' returned in the /products/list or /products/search endpoints. Defaults to \'1727\'.", "type": "str, optional", "default": "1727"}, "sku": {"description": "The value of \'goods_sn\' returned in the /products/list or /products/search endpoints. Defaults to \'rm2202285484176751\'.", "type": "str, optional", "default": "rm2202285484176751"}, "currency": {"description": "The 3-letter currency code. Defaults to \'USD\'.", "type": "str, optional", "default": "USD"}, "goods_id": {"description": "The value of \'goods_id\' field returned in the /products/list or /products/search endpoints. Defaults to \'10196865\'.", "type": "str, optional", "default": "10196865"}, "language": {"description": "The 2-letter language code. Defaults to \'en\'.", "type": "str, optional", "default": "en"}, "country": {"description": "The 2-letter country code. Defaults to \'US\'.", "type": "str, optional", "default": "US"}}}]',
+ }
+ ]
+ )
+ )[0]
+ assert "query" in result
+ assert "answers" in result
+ query = result["query"]
+ assert isinstance(query, str)
+ answers = json.loads(result["answers"])
+ assert isinstance(answers, list)
+ assert len(answers) == number
diff --git a/tests/unit/steps/tasks/apigen/test_semantic_checker.py b/tests/unit/steps/tasks/apigen/test_semantic_checker.py
new file mode 100644
index 0000000000..e73b71c3a0
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/test_semantic_checker.py
@@ -0,0 +1,113 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict
+
+import pytest
+
+from distilabel.steps.tasks.apigen.semantic_checker import APIGenSemanticChecker
+from tests.unit.conftest import DummyLLM
+
+SAMPLE_DATA = [
+ # The info can for the function description can be obtained from the tool itself
+ {
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "execution_result": "Hopefully some info about the Maine Coon",
+ },
+ {
+ "func_desc": "Checks if an email domain is valid or a disposable/temporary address.",
+ "query": "Check if the email domains 'protonmail.com' and 'mail.com' are valid and not temporary. Get the products from category 'furniture' in my store, skipping the first 20 items and limiting to 25 items.",
+ "answers": '[{"name": "mailcheck", "arguments": {"domain": "protonmail.com"}}, {"name": "mailcheck", "arguments": {"domain": "mail.com"}}, {"name": "get_products_in_category", "arguments": {"skip": 20, "limit": 25, "category": "furniture"}}]',
+ "execution_result": "Response for the emails",
+ },
+ {
+ "func_desc": "Fetches the content of a node in a navigation hierarchy.",
+ "query": "What are the node contents for category IDs 8899 and 7766 in English and for category IDs 5544 and 3322 in French?",
+ "answers": '[{"name": "navigations_get_node_content", "arguments": {"is_id": 8899, "cat_id": 8899, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 7766, "cat_id": 7766, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 5544, "cat_id": 5544, "language": "fr"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 3322, "cat_id": 3322, "language": "fr"}}]',
+ "execution_result": "Response for the node contents",
+ },
+]
+
+
+class TestAPIGenSemanticChecker:
+ @pytest.mark.parametrize("use_default_structured_output", [True, False])
+ def test_format_input(self, use_default_structured_output: bool) -> None:
+ task = APIGenSemanticChecker(
+ llm=DummyLLM(),
+ use_default_structured_output=use_default_structured_output,
+ )
+ task.load()
+ result = task.format_input(SAMPLE_DATA[0])
+ assert isinstance(result, list)
+ formatted_prompt = result[1]["content"]
+
+ default_structured_output_check = "Your response MUST strictly adhere to the following JSON format, and NO other text MUST be included"
+ assert default_structured_output_check in formatted_prompt
+ assert (
+ '- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]'
+ in formatted_prompt
+ )
+ assert (
+ "- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API."
+ in formatted_prompt
+ )
+ assert (
+ "- Execution Results: Hopefully some info about the Maine Coon"
+ in formatted_prompt
+ )
+
+ @pytest.mark.parametrize(
+ "result, expected",
+ [
+ (
+ '{"thought": "thought", "keep_row_after_semantic_check": "no", "passes": "no"}',
+ {
+ "thought": "thought",
+ "keep_row_after_semantic_check": False,
+ "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "execution_result": "Hopefully some info about the Maine Coon",
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ },
+ ),
+ (
+ None,
+ {
+ "thought": None,
+ "keep_row_after_semantic_check": None,
+ "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "execution_result": "Hopefully some info about the Maine Coon",
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ },
+ ),
+ (
+ "wrong",
+ {
+ "thought": None,
+ "keep_row_after_semantic_check": None,
+ "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "execution_result": "Hopefully some info about the Maine Coon",
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ },
+ ),
+ ],
+ )
+ def test_format_output(self, result: str, expected: Dict[str, Any]) -> None:
+ task = APIGenSemanticChecker(llm=DummyLLM())
+ task.load()
+ assert task.format_output(result, SAMPLE_DATA[0]) == expected
diff --git a/tests/unit/steps/tasks/apigen/test_utils.py b/tests/unit/steps/tasks/apigen/test_utils.py
new file mode 100644
index 0000000000..00707f17a9
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/test_utils.py
@@ -0,0 +1,77 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import Any, Dict
+
+import pytest
+
+from distilabel.steps.tasks.apigen.utils import (
+ execute_from_response,
+ load_module_from_path,
+)
+
+
+@pytest.mark.parametrize(
+ "function_name, answer, expected_result",
+ [
+ (
+ "final_velocity",
+ {"initial_velocity": 10, "acceleration": 5, "time": 2},
+ {"execution_result": "20", "keep": True},
+ ),
+ # In this case, internally we should cast the arguments
+ (
+ "final_velocity",
+ {"initial_velocity": "10", "acceleration": "5", "time": "2"},
+ {"execution_result": "20", "keep": True},
+ ),
+ # Different names for the arguments but correctly positioned
+ (
+ "final_velocity",
+ {"v0": "10", "a": "5", "t": "2"},
+ {"execution_result": "20", "keep": True},
+ ),
+ # Fail casting one of the values
+ (
+ "final_velocity",
+ {"initial_velocity": "10", "acceleration": "5", "time": "1m/s"},
+ {
+ "execution_result": "unsupported operand type(s) for +: 'int' and 'str'",
+ "keep": False,
+ },
+ ),
+ (
+ "final_velocity",
+ {"initial_velocity": 10, "acceleration": 5},
+ {
+ "execution_result": "final_velocity() missing 1 required positional argument: 'time'",
+ "keep": False,
+ },
+ ),
+ (
+ "unknwown_function",
+ {"initial_velocity": 10, "acceleration": 5, "time": 2},
+ {"execution_result": "Function not found", "keep": False},
+ ),
+ ],
+)
+def test_execute_from_response(
+ function_name: str, answer: Dict[str, Any], expected_result: Dict[str, Any]
+):
+ libpath = Path(__file__).parent / "_sample_module.py"
+ libpath = load_module_from_path(libpath)
+ final_velocity = getattr(libpath, function_name, None)
+ result = execute_from_response(final_velocity, answer)
+ assert result == expected_result
diff --git a/tests/unit/steps/tasks/evol_instruct/test_base.py b/tests/unit/steps/tasks/evol_instruct/test_base.py
index 29edfb4c13..66f67347b1 100644
--- a/tests/unit/steps/tasks/evol_instruct/test_base.py
+++ b/tests/unit/steps/tasks/evol_instruct/test_base.py
@@ -123,6 +123,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
assert task.dump() == {
"name": "task",
"add_raw_output": True,
+ "add_raw_input": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"resources": {
@@ -135,6 +136,10 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"input_batch_size": task.input_batch_size,
"llm": {
"generation_kwargs": {},
+ "structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": task.llm.__module__,
"name": task.llm.__class__.__name__,
@@ -153,6 +158,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"INCREASED_REASONING_STEPS": "I want you act as a Prompt Rewriter.\n\nYour objective is to rewrite a given prompt into a more complex version to make those famous AI systems (e.g., chatgpt and GPT4) a bit harder to handle.\n\nBut the rewritten prompt must be reasonable and must be understood and responded by humans.\n\nYour rewriting cannot omit the non-text parts such as the table and code in #The Given Prompt#:. Also, please do not omit the input in #The Given Prompt#.\n\nYou SHOULD complicate the given prompt using the following method: \nIf #The Given Prompt# can be solved with just a few simple thinking processes, you can rewrite it to explicitly request multiple-step reasoning.\n\nYou should try your best not to make the #Rewritten Prompt# become verbose, #Rewritten Prompt# can only add 10 to 20 words into #The Given Prompt#.\n\n'#The Given Prompt#', '#Rewritten Prompt#', 'given prompt' and 'rewritten prompt' are not allowed to appear in #Rewritten Prompt#\n\n#The Given Prompt#:\n\n#Rewritten Prompt#:\n\n",
"BREADTH": "I want you act as a Prompt Creator.\n\nYour goal is to draw inspiration from the #Given Prompt# to create a brand new prompt.\n\nThis new prompt should belong to the same domain as the #Given Prompt# but be even more rare.\n\nThe LENGTH and complexity of the #Created Prompt# should be similar to that of the #Given Prompt#.\n\nThe #Created Prompt# must be reasonable and must be understood and responded by humans.\n\n'#Given Prompt#', '#Created Prompt#', 'given prompt' and 'created prompt' are not allowed to appear in #Created Prompt#\n\n#Given Prompt#:\n\n#Created Prompt#:\n\n",
},
+ "use_default_structured_output": False,
"seed": task.seed,
"runtime_parameters_info": [
{
@@ -197,7 +203,21 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "generation_kwargs",
"description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.",
"keys": [],
- }
+ },
+ {
+ "description": "Whether to use the `offline_batch_generate` method to "
+ "generate the responses.",
+ "name": "use_offline_batch_generation",
+ "optional": True,
+ },
+ {
+ "description": "If provided, then polling will be done until the "
+ "`ofline_batch_generate` method is able to retrieve the "
+ "results. The value indicate the time to wait between each "
+ "polling.",
+ "name": "offline_batch_generation_block_until_done",
+ "optional": True,
+ },
],
},
{
@@ -205,6 +225,11 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "add_raw_output",
"optional": True,
},
+ {
+ "description": "Whether to include the raw input of the LLM in the key `raw_input_` of the `distilabel_metadata` dictionary column",
+ "name": "add_raw_input",
+ "optional": True,
+ },
{
"name": "num_generations",
"optional": True,
@@ -216,6 +241,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"description": "As `numpy` is being used in order to randomly pick a mutation method, then is nice to seed a random seed.",
},
],
+ "use_cache": True,
"type_info": {
"module": "distilabel.steps.tasks.evol_instruct.base",
"name": "EvolInstruct",
diff --git a/tests/unit/steps/tasks/evol_instruct/test_generator.py b/tests/unit/steps/tasks/evol_instruct/test_generator.py
index 754fa846ed..8f86b94908 100644
--- a/tests/unit/steps/tasks/evol_instruct/test_generator.py
+++ b/tests/unit/steps/tasks/evol_instruct/test_generator.py
@@ -118,12 +118,17 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "task",
"llm": {
"generation_kwargs": {},
+ "structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": task.llm.__class__.__module__,
"name": task.llm.__class__.__name__,
},
},
"add_raw_output": True,
+ "add_raw_input": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"resources": {
@@ -149,6 +154,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"min_length": task.min_length,
"max_length": task.max_length,
"seed": task.seed,
+ "use_default_structured_output": False,
"runtime_parameters_info": [
{
"name": "resources",
@@ -193,6 +199,20 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"keys": [],
"name": "generation_kwargs",
},
+ {
+ "description": "Whether to use the `offline_batch_generate` method to "
+ "generate the responses.",
+ "name": "use_offline_batch_generation",
+ "optional": True,
+ },
+ {
+ "description": "If provided, then polling will be done until the "
+ "`ofline_batch_generate` method is able to retrieve the "
+ "results. The value indicate the time to wait between each "
+ "polling.",
+ "name": "offline_batch_generation_block_until_done",
+ "optional": True,
+ },
],
},
{
@@ -200,6 +220,11 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "add_raw_output",
"optional": True,
},
+ {
+ "description": "Whether to include the raw input of the LLM in the key `raw_input_` of the `distilabel_metadata` dictionary column",
+ "name": "add_raw_input",
+ "optional": True,
+ },
{
"name": "num_generations",
"optional": True,
@@ -221,6 +246,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"description": "As `numpy` is being used in order to randomly pick a mutation method, then is nice to seed a random seed.",
},
],
+ "use_cache": True,
"type_info": {
"module": EvolInstructGenerator.__module__,
"name": EvolInstructGenerator.__name__,
diff --git a/tests/unit/steps/tasks/evol_quality/test_base.py b/tests/unit/steps/tasks/evol_quality/test_base.py
index 7903a95fda..2ac460afc4 100644
--- a/tests/unit/steps/tasks/evol_quality/test_base.py
+++ b/tests/unit/steps/tasks/evol_quality/test_base.py
@@ -94,6 +94,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
assert task.dump() == {
"name": "task",
"add_raw_output": True,
+ "add_raw_input": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"resources": {
@@ -106,6 +107,10 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"input_batch_size": task.input_batch_size,
"llm": {
"generation_kwargs": {},
+ "structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": task.llm.__module__,
"name": task.llm.__class__.__name__,
@@ -118,6 +123,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"group_generations": task.group_generations,
"include_original_response": task.include_original_response,
"seed": task.seed,
+ "use_default_structured_output": False,
"runtime_parameters_info": [
{
"name": "resources",
@@ -162,6 +168,20 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.",
"keys": [],
},
+ {
+ "description": "Whether to use the `offline_batch_generate` method to "
+ "generate the responses.",
+ "name": "use_offline_batch_generation",
+ "optional": True,
+ },
+ {
+ "description": "If provided, then polling will be done until the "
+ "`ofline_batch_generate` method is able to retrieve the "
+ "results. The value indicate the time to wait between each "
+ "polling.",
+ "name": "offline_batch_generation_block_until_done",
+ "optional": True,
+ },
],
},
{
@@ -169,6 +189,11 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "add_raw_output",
"optional": True,
},
+ {
+ "description": "Whether to include the raw input of the LLM in the key `raw_input_` of the `distilabel_metadata` dictionary column",
+ "name": "add_raw_input",
+ "optional": True,
+ },
{
"name": "num_generations",
"optional": True,
@@ -180,6 +205,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"description": "As `numpy` is being used in order to randomly pick a mutation method, then is nice to set a random seed.",
},
],
+ "use_cache": True,
"type_info": {
"module": task.__module__,
"name": task.__class__.__name__,
diff --git a/tests/unit/steps/tasks/magpie/test_base.py b/tests/unit/steps/tasks/magpie/test_base.py
index 4a62836b14..cc13681f9f 100644
--- a/tests/unit/steps/tasks/magpie/test_base.py
+++ b/tests/unit/steps/tasks/magpie/test_base.py
@@ -13,6 +13,7 @@
# limitations under the License.
import random
+from typing import Any, Dict
from unittest import mock
import pytest
@@ -30,7 +31,28 @@ def test_raise_value_error_llm_no_magpie_mixin(self) -> None:
):
Magpie(llm=OpenAILLM(model="gpt-4", api_key="fake")) # type: ignore
+ def test_raise_error_if_system_prompts_weights_do_not_sum_to_one(self) -> None:
+ with pytest.raises(
+ ValueError,
+ match="`*If `system_prompts` attribute is a dictionary containing tuples with"
+ " the system prompts and their probability of being chosen",
+ ):
+ Magpie(
+ llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+ system_prompt={
+ "system_prompt_1": ("system_prompt", 0.5),
+ "system_prompt_2": ("system_prompt", 0.4),
+ },
+ )
+
def test_outputs(self) -> None:
+ task = Magpie(
+ llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+ only_instruction=True,
+ )
+
+ assert task.outputs == ["instruction", "model_name"]
+
task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=1)
assert task.outputs == ["instruction", "response", "model_name"]
@@ -41,10 +63,18 @@ def test_outputs(self) -> None:
task = Magpie(
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
- only_instruction=True,
+ system_prompt={
+ "system_prompt_1": ("system_prompt", 0.5),
+ "system_prompt_2": ("system_prompt", 0.5),
+ },
)
- assert task.outputs == ["instruction", "model_name"]
+ assert task.outputs == [
+ "instruction",
+ "response",
+ "system_prompt_key",
+ "model_name",
+ ]
def test_process(self) -> None:
task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=1)
@@ -130,7 +160,7 @@ def test_process_with_several_system_prompts(self) -> None:
assert next(task.process(inputs=[{}, {}, {}])) == [
{
"conversation": [
- {"role": "system", "content": "This is a system prompt."},
+ {"role": "system", "content": "This is another system prompt."},
{"role": "user", "content": "Hello Magpie"},
{"role": "assistant", "content": "Hello Magpie"},
{"role": "user", "content": "Hello Magpie"},
@@ -150,7 +180,7 @@ def test_process_with_several_system_prompts(self) -> None:
},
{
"conversation": [
- {"role": "system", "content": "This is another system prompt."},
+ {"role": "system", "content": "This is a system prompt."},
{"role": "user", "content": "Hello Magpie"},
{"role": "assistant", "content": "Hello Magpie"},
{"role": "user", "content": "Hello Magpie"},
@@ -365,6 +395,46 @@ def test_process_with_system_prompt_per_row(self) -> None:
},
]
+ def test_process_with_system_prompt_and_probabilities(self) -> None:
+ with mock.patch(
+ "random.choices",
+ side_effect=[
+ ["system_prompt_1"],
+ ["system_prompt_2"],
+ ["system_prompt_1"],
+ ],
+ ):
+ task = Magpie(
+ llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+ system_prompt={
+ "system_prompt_1": ("system_prompt", 0.6),
+ "system_prompt_2": ("system_prompt", 0.4),
+ },
+ )
+
+ task.load()
+
+ assert next(task.process(inputs=[{}, {}, {}])) == [
+ {
+ "instruction": "Hello Magpie",
+ "response": "Hello Magpie",
+ "system_prompt_key": "system_prompt_1",
+ "model_name": "test",
+ },
+ {
+ "instruction": "Hello Magpie",
+ "response": "Hello Magpie",
+ "system_prompt_key": "system_prompt_2",
+ "model_name": "test",
+ },
+ {
+ "instruction": "Hello Magpie",
+ "response": "Hello Magpie",
+ "system_prompt_key": "system_prompt_1",
+ "model_name": "test",
+ },
+ ]
+
def test_process_only_instruction(self) -> None:
task = Magpie(
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
@@ -388,6 +458,173 @@ def test_process_only_instruction(self) -> None:
},
]
+ @pytest.mark.parametrize(
+ "conversation, include_system_prompt, n_turns, expected",
+ [
+ (
+ [
+ {"role": "user", "content": "Hello Magpie"},
+ {"role": "assistant", "content": "Hello user"},
+ ],
+ False,
+ 1,
+ {"instruction": "Hello Magpie", "response": "Hello user"},
+ ),
+ (
+ [
+ {"role": "user", "content": "Hello Magpie"},
+ {"role": "assistant", "content": "Hello user"},
+ ],
+ False,
+ 1,
+ {"instruction": "Hello Magpie", "response": "Hello user"},
+ ),
+ (
+ [
+ {"role": "system", "content": "This is a system prompt."},
+ {"role": "user", "content": "Hello Magpie"},
+ {"role": "assistant", "content": "Hello user"},
+ {"role": "user", "content": "How are you?"},
+ {"role": "assistant", "content": "I'm fine thank you."},
+ ],
+ True,
+ 2,
+ {
+ "conversation": [
+ {"role": "system", "content": "This is a system prompt."},
+ {"role": "user", "content": "Hello Magpie"},
+ {"role": "assistant", "content": "Hello user"},
+ {"role": "user", "content": "How are you?"},
+ {"role": "assistant", "content": "I'm fine thank you."},
+ ],
+ },
+ ),
+ (
+ [
+ {"role": "system", "content": "This is a system prompt."},
+ {"role": "user", "content": "Hello Magpie"},
+ {"role": "assistant", "content": "Hello user"},
+ {"role": "user", "content": "How are you?"},
+ {"role": "assistant", "content": "I'm fine thank you."},
+ ],
+ False,
+ 2,
+ {
+ "conversation": [
+ {"role": "user", "content": "Hello Magpie"},
+ {"role": "assistant", "content": "Hello user"},
+ {"role": "user", "content": "How are you?"},
+ {"role": "assistant", "content": "I'm fine thank you."},
+ ],
+ },
+ ),
+ (
+ [],
+ False,
+ 1,
+ {"instruction": None, "response": None},
+ ),
+ (
+ [],
+ False,
+ 2,
+ {"conversation": []},
+ ),
+ ],
+ )
+ def test_prepare_conversation_outputs(
+ self,
+ conversation,
+ include_system_prompt: bool,
+ n_turns: int,
+ expected: Dict[str, Any],
+ ) -> None:
+ task = Magpie(
+ llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+ n_turns=n_turns,
+ include_system_prompt=include_system_prompt,
+ )
+ assert task._prepare_conversation_outputs([conversation], []) == [expected]
+
+ @pytest.mark.parametrize(
+ "system_prompt, n_turns, inputs, random_choices_return, expected_prepared_inputs, expected_system_prompt_keys",
+ [
+ (
+ None,
+ 1,
+ [{"system_prompt": "Custom system prompt."}],
+ None,
+ [[{"role": "system", "content": "Custom system prompt."}]],
+ [],
+ ),
+ (
+ ["Prompt A", "Prompt B"],
+ 1,
+ [{}],
+ ["Prompt A"],
+ [[{"role": "system", "content": "Prompt A"}]],
+ [],
+ ),
+ (
+ {"Key1": "Prompt 1", "Key2": "Prompt 2"},
+ 1,
+ [{}],
+ ["Key1"],
+ [[{"role": "system", "content": "Prompt 1"}]],
+ ["Key1"],
+ ),
+ (
+ {"Key1": ("Prompt 1", 0.7), "Key2": ("Prompt 2", 0.3)},
+ 1,
+ [{}],
+ ["Key1"],
+ [[{"role": "system", "content": "Prompt 1"}]],
+ ["Key1"],
+ ),
+ (
+ None,
+ 2,
+ [{}],
+ None,
+ [[{"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT}]],
+ [],
+ ),
+ (
+ None,
+ 1,
+ [{}],
+ None,
+ [[]],
+ [],
+ ),
+ ],
+ )
+ def test_prepare_inputs_for_instruction_generation(
+ self,
+ system_prompt,
+ n_turns,
+ inputs,
+ random_choices_return,
+ expected_prepared_inputs,
+ expected_system_prompt_keys,
+ ):
+ task = Magpie(
+ llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+ n_turns=n_turns,
+ system_prompt=system_prompt,
+ )
+
+ with mock.patch("random.choices") as mock_choices:
+ if random_choices_return is not None:
+ mock_choices.return_value = random_choices_return
+
+ prepared_inputs, system_prompt_keys = (
+ task._prepare_inputs_for_instruction_generation(inputs)
+ )
+
+ assert prepared_inputs == expected_prepared_inputs
+ assert system_prompt_keys == expected_system_prompt_keys
+
def test_serialization(self) -> None:
task = Magpie(
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
@@ -399,6 +636,9 @@ def test_serialization(self) -> None:
"use_magpie_template": True,
"magpie_pre_query_template": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n",
"generation_kwargs": {},
+ "use_offline_batch_generation": False,
+ "offline_batch_generation_block_until_done": None,
+ "jobs_ids": None,
"type_info": {
"module": "tests.unit.conftest",
"name": "DummyMagpieLLM",
@@ -422,7 +662,9 @@ def test_serialization(self) -> None:
"input_batch_size": 50,
"group_generations": False,
"add_raw_output": True,
+ "add_raw_input": True,
"num_generations": 1,
+ "use_default_structured_output": False,
"runtime_parameters_info": [
{
"name": "llm",
@@ -431,7 +673,17 @@ def test_serialization(self) -> None:
"name": "generation_kwargs",
"description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.",
"keys": [{"name": "kwargs", "optional": False}],
- }
+ },
+ {
+ "name": "use_offline_batch_generation",
+ "optional": True,
+ "description": "Whether to use the `offline_batch_generate` method to generate the responses.",
+ },
+ {
+ "name": "offline_batch_generation_block_until_done",
+ "optional": True,
+ "description": "If provided, then polling will be done until the `ofline_batch_generate` method is able to retrieve the results. The value indicate the time to wait between each polling.",
+ },
],
},
{
@@ -457,7 +709,7 @@ def test_serialization(self) -> None:
{
"name": "system_prompt",
"optional": True,
- "description": "An optional system prompt or list of system prompts that can be used to steer the LLM to generate content of certain topic, guide the style, etc.",
+ "description": "An optional system prompt, or a list of system prompts from which a random one will be chosen, or a dictionary of system prompts from which a random one will be choosen, or a dictionary of system prompts with their probability of being chosen. The random system prompt will be chosen per input/output batch. This system prompt can be used to guide the generation of the instruct LLM and steer it to generate instructions of a certain topic.",
},
{
"name": "resources",
@@ -499,12 +751,18 @@ def test_serialization(self) -> None:
"optional": True,
"description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column",
},
+ {
+ "name": "add_raw_input",
+ "optional": True,
+ "description": "Whether to include the raw input of the LLM in the key `raw_input_` of the `distilabel_metadata` dictionary column",
+ },
{
"name": "num_generations",
"optional": True,
"description": "The number of generations to be produced per input.",
},
],
+ "use_cache": True,
"type_info": {
"module": "distilabel.steps.tasks.magpie.base",
"name": "Magpie",
diff --git a/tests/unit/steps/tasks/magpie/test_generator.py b/tests/unit/steps/tasks/magpie/test_generator.py
index b72aa91e62..d1d1426351 100644
--- a/tests/unit/steps/tasks/magpie/test_generator.py
+++ b/tests/unit/steps/tasks/magpie/test_generator.py
@@ -28,6 +28,13 @@ def test_raise_value_error_llm_no_magpie_mixin(self) -> None:
MagpieGenerator(llm=OpenAILLM(model="gpt-4", api_key="fake")) # type: ignore
def test_outputs(self) -> None:
+ task = MagpieGenerator(
+ llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+ only_instruction=True,
+ )
+
+ assert task.outputs == ["instruction", "model_name"]
+
task = MagpieGenerator(
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=1
)
@@ -42,10 +49,18 @@ def test_outputs(self) -> None:
task = MagpieGenerator(
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
- only_instruction=True,
+ system_prompt={
+ "system_prompt_1": ("system_prompt", 0.5),
+ "system_prompt_2": ("system_prompt", 0.5),
+ },
)
- assert task.outputs == ["instruction", "model_name"]
+ assert task.outputs == [
+ "instruction",
+ "response",
+ "system_prompt_key",
+ "model_name",
+ ]
def test_serialization(self) -> None:
task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))
@@ -55,6 +70,9 @@ def test_serialization(self) -> None:
"use_magpie_template": True,
"magpie_pre_query_template": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n",
"generation_kwargs": {},
+ "use_offline_batch_generation": False,
+ "offline_batch_generation_block_until_done": None,
+ "jobs_ids": None,
"type_info": {
"module": "tests.unit.conftest",
"name": "DummyMagpieLLM",
@@ -78,7 +96,9 @@ def test_serialization(self) -> None:
"batch_size": 50,
"group_generations": False,
"add_raw_output": True,
+ "add_raw_input": True,
"num_generations": 1,
+ "use_default_structured_output": False,
"num_rows": None,
"runtime_parameters_info": [
{
@@ -88,7 +108,17 @@ def test_serialization(self) -> None:
"name": "generation_kwargs",
"description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.",
"keys": [{"name": "kwargs", "optional": False}],
- }
+ },
+ {
+ "name": "use_offline_batch_generation",
+ "optional": True,
+ "description": "Whether to use the `offline_batch_generate` method to generate the responses.",
+ },
+ {
+ "name": "offline_batch_generation_block_until_done",
+ "optional": True,
+ "description": "If provided, then polling will be done until the `ofline_batch_generate` method is able to retrieve the results. The value indicate the time to wait between each polling.",
+ },
],
},
{
@@ -114,7 +144,7 @@ def test_serialization(self) -> None:
{
"name": "system_prompt",
"optional": True,
- "description": "An optional system prompt or list of system prompts that can be used to steer the LLM to generate content of certain topic, guide the style, etc.",
+ "description": "An optional system prompt, or a list of system prompts from which a random one will be chosen, or a dictionary of system prompts from which a random one will be choosen, or a dictionary of system prompts with their probability of being chosen. The random system prompt will be chosen per input/output batch. This system prompt can be used to guide the generation of the instruct LLM and steer it to generate instructions of a certain topic.",
},
{
"name": "resources",
@@ -156,6 +186,11 @@ def test_serialization(self) -> None:
"optional": True,
"description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column",
},
+ {
+ "name": "add_raw_input",
+ "optional": True,
+ "description": "Whether to include the raw input of the LLM in the key `raw_input_` of the `distilabel_metadata` dictionary column",
+ },
{
"name": "num_generations",
"optional": True,
@@ -167,6 +202,7 @@ def test_serialization(self) -> None:
"description": "The number of rows to generate.",
},
],
+ "use_cache": True,
"type_info": {
"module": "distilabel.steps.tasks.magpie.generator",
"name": "MagpieGenerator",
diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py
index f996b23b48..d2be053aa5 100644
--- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py
+++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py
@@ -35,6 +35,9 @@ class DummyUserTest(BaseModel):
"cuda_devices": "auto",
"generation_kwargs": {},
"magpie_pre_query_template": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"structured_output": {
"format": "json",
"schema": {
@@ -71,6 +74,9 @@ class DummyUserTest(BaseModel):
"cuda_devices": "auto",
"generation_kwargs": {},
"magpie_pre_query_template": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"structured_output": {
"format": "regex",
"schema": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
diff --git a/tests/unit/steps/tasks/test_argilla_labeller.py b/tests/unit/steps/tasks/test_argilla_labeller.py
new file mode 100644
index 0000000000..926118dd6c
--- /dev/null
+++ b/tests/unit/steps/tasks/test_argilla_labeller.py
@@ -0,0 +1,210 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Any, Dict, List
+
+import pytest
+
+from distilabel.pipeline.local import Pipeline
+from distilabel.steps.tasks.argilla_labeller import ArgillaLabeller
+from distilabel.steps.tasks.typing import ChatItem
+from tests.unit.conftest import DummyAsyncLLM
+
+
+@pytest.fixture
+def fields() -> Dict[str, Any]:
+ return [
+ {
+ "name": "text",
+ "description": "The text of the question",
+ "title": "The text of the question",
+ "settings": {"type": "text"},
+ }
+ ]
+
+
+@pytest.fixture
+def questions() -> List[Dict[str, Any]]:
+ return [
+ {
+ "name": "label_selection",
+ "description": "The class of the question",
+ "title": "Is the question a question?",
+ "settings": {
+ "type": "label_selection",
+ "options": [
+ {"value": "yes", "text": "Yes"},
+ {"value": "no", "text": "No"},
+ ],
+ },
+ },
+ {
+ "name": "multi_label_selection",
+ "description": "The class of the question",
+ "title": "Is the question a question?",
+ "settings": {
+ "type": "multi_label_selection",
+ "options": [
+ {"value": "yes", "text": "Yes"},
+ {"value": "no", "text": "No"},
+ ],
+ },
+ },
+ {
+ "name": "rating",
+ "description": "The class of the question",
+ "title": "Is the question a question?",
+ "settings": {
+ "type": "rating",
+ "options": [
+ {"value": "1", "text": "1"},
+ ],
+ },
+ },
+ {
+ "name": "text",
+ "description": "The class of the question",
+ "title": "Is the question a question?",
+ "settings": {
+ "type": "text",
+ },
+ },
+ ]
+
+
+@pytest.fixture
+def outputs() -> List[Dict[str, Any]]:
+ return [
+ {
+ "label": "yes",
+ },
+ {
+ "labels": ["yes", "no"],
+ },
+ {
+ "rating": "1",
+ },
+ {
+ "text": "yes",
+ },
+ ]
+
+
+@pytest.fixture
+def records() -> List[Dict[str, Any]]:
+ return [
+ {
+ "fields": {
+ "text": "What is the capital of France?",
+ },
+ "responses": [
+ {
+ "quesion_name": "label_selection",
+ "value": "yes",
+ }
+ ],
+ }
+ ]
+
+
+class TestArgillaLabeller:
+ def test_format_input(
+ self,
+ questions: List[Dict[str, Any]],
+ records: List[Dict[str, Any]],
+ fields: List[Dict[str, Any]],
+ ) -> None:
+ task = ArgillaLabeller(
+ name="argilla_labeller",
+ llm=DummyAsyncLLM(),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
+ for question in questions:
+ result: List[ChatItem] = task.format_input(
+ input={
+ "question": question,
+ "fields": fields,
+ "record": records[0],
+ }
+ )
+ assert question["description"] in result[-1]["content"]
+ assert question["title"] in result[-1]["content"]
+ if question["settings"]["type"] in [
+ "label_selection",
+ "multi_label_selection",
+ "span",
+ "rating",
+ ]:
+ assert (
+ question["settings"]["options"][0]["value"] in result[-1]["content"]
+ )
+
+ def test_format_output(
+ self,
+ questions: List[Dict[str, Any]],
+ records: List[Dict[str, Any]],
+ fields: List[Dict[str, Any]],
+ outputs: List[Dict[str, Any]],
+ ) -> None:
+ task = ArgillaLabeller(
+ name="argilla_labeller",
+ llm=DummyAsyncLLM(),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
+ for question, output in zip(questions, outputs):
+ task.format_output(
+ input={
+ "question": question,
+ "fields": fields,
+ "record": records[0],
+ },
+ output=json.dumps(output),
+ )
+
+ def test_fail_on_invalid_question_type(
+ self, questions: List[Dict[str, Any]], records: List[Dict[str, Any]]
+ ) -> None:
+ task = ArgillaLabeller(
+ name="argilla_labeller",
+ llm=DummyAsyncLLM(),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
+ fake_question = questions[0]
+ fake_question["settings"]["type"] = "invalid_type"
+
+ with pytest.raises(ValueError):
+ task.format_input(
+ input={
+ "record": records[0],
+ "question": fake_question,
+ }
+ )
+
+ def test_fail_on_no_question(self, records: List[Dict[str, Any]]) -> None:
+ task = ArgillaLabeller(
+ name="argilla_labeller",
+ llm=DummyAsyncLLM(),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
+ with pytest.raises(ValueError):
+ task.format_input(input={"record": records[0]})
diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py
index 745be0ba6c..29341052fb 100644
--- a/tests/unit/steps/tasks/test_base.py
+++ b/tests/unit/steps/tasks/test_base.py
@@ -14,7 +14,7 @@
import sys
from dataclasses import field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
import pytest
from pydantic import ValidationError
@@ -22,42 +22,38 @@
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.base import Task
-from tests.unit.conftest import DummyLLM
+from tests.unit.conftest import (
+ DummyAsyncLLM,
+ DummyTask,
+ DummyTaskOfflineBatchGeneration,
+)
if TYPE_CHECKING:
- from distilabel.steps.tasks.typing import ChatType
+ pass
-class DummyTask(Task):
- @property
- def inputs(self) -> List[str]:
- return ["instruction", "additional_info"]
-
- def format_input(self, input: Dict[str, Any]) -> "ChatType":
- return [
- {"role": "system", "content": ""},
- {"role": "user", "content": input["instruction"]},
- ]
-
- @property
- def outputs(self) -> List[str]:
- return ["output", "info_from_input"]
-
- def format_output(
- self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
- ) -> Dict[str, Any]:
- return {"output": output, "info_from_input": input["additional_info"]} # type: ignore
-
-
-class DummyRuntimeLLM(DummyLLM):
+class DummyRuntimeLLM(DummyAsyncLLM):
runtime_parameter: RuntimeParameter[int]
runtime_parameter_optional: Optional[RuntimeParameter[int]] = field(default=None)
class TestTask:
+ def test_model_post_init_raise_valuerror_use_offline_batch_generation(self) -> None:
+ with pytest.raises(
+ ValidationError,
+ match="`DummyTask` task cannot be used with offline batch generation",
+ ):
+ DummyTask(llm=DummyAsyncLLM(use_offline_batch_generation=True))
+
+ def test_is_global_with_offline_batch_generation(self) -> None:
+ task = DummyTaskOfflineBatchGeneration(
+ llm=DummyAsyncLLM(use_offline_batch_generation=True)
+ )
+ assert task.is_global is True
+
def test_passing_pipeline(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
- llm = DummyLLM()
+ llm = DummyAsyncLLM()
task = DummyTask(name="task", llm=llm, pipeline=pipeline)
assert task.name == "task"
assert task.llm is llm
@@ -67,14 +63,14 @@ def test_passing_pipeline(self) -> None:
def test_within_pipeline_context(self) -> None:
with Pipeline(name="unit-test-pipeline") as pipeline:
- llm = DummyLLM()
+ llm = DummyAsyncLLM()
task = DummyTask(name="task", llm=llm, pipeline=pipeline)
assert task.name == "task"
assert task.llm is llm
assert task.pipeline == pipeline
def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
- DummyTask(name="task", llm=DummyLLM())
+ DummyTask(name="task", llm=DummyAsyncLLM())
assert "Step 'task' hasn't received a pipeline" in caplog.text
with pytest.raises(
@@ -88,7 +84,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
if sys.version_info < (3, 12)
else "Can't instantiate abstract class Task without an implementation for abstract methods 'format_input', 'format_output'",
):
- Task(name="task", llm=DummyLLM()) # type: ignore
+ Task(name="task", llm=DummyAsyncLLM()) # type: ignore
@pytest.mark.parametrize(
"input, group_generations, expected",
@@ -107,7 +103,13 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"output": "output",
"info_from_input": "additional_info_0",
"model_name": "test",
- "distilabel_metadata": {"raw_output_task": "output"},
+ "distilabel_metadata": {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {"content": "", "role": "system"},
+ {"content": "test_0", "role": "user"},
+ ],
+ },
},
{
"instruction": "test_0",
@@ -115,7 +117,13 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"output": "output",
"info_from_input": "additional_info_0",
"model_name": "test",
- "distilabel_metadata": {"raw_output_task": "output"},
+ "distilabel_metadata": {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {"content": "", "role": "system"},
+ {"content": "test_0", "role": "user"},
+ ],
+ },
},
{
"instruction": "test_0",
@@ -123,7 +131,13 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"output": "output",
"info_from_input": "additional_info_0",
"model_name": "test",
- "distilabel_metadata": {"raw_output_task": "output"},
+ "distilabel_metadata": {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {"content": "", "role": "system"},
+ {"content": "test_0", "role": "user"},
+ ],
+ },
},
{
"instruction": "test_1",
@@ -131,7 +145,13 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"output": "output",
"info_from_input": "additional_info_1",
"model_name": "test",
- "distilabel_metadata": {"raw_output_task": "output"},
+ "distilabel_metadata": {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {"content": "", "role": "system"},
+ {"content": "test_1", "role": "user"},
+ ],
+ },
},
{
"instruction": "test_1",
@@ -139,7 +159,13 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"output": "output",
"info_from_input": "additional_info_1",
"model_name": "test",
- "distilabel_metadata": {"raw_output_task": "output"},
+ "distilabel_metadata": {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {"content": "", "role": "system"},
+ {"content": "test_1", "role": "user"},
+ ],
+ },
},
{
"instruction": "test_1",
@@ -147,7 +173,13 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"output": "output",
"info_from_input": "additional_info_1",
"model_name": "test",
- "distilabel_metadata": {"raw_output_task": "output"},
+ "distilabel_metadata": {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {"content": "", "role": "system"},
+ {"content": "test_1", "role": "user"},
+ ],
+ },
},
{
"instruction": "test_2",
@@ -155,7 +187,13 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"output": "output",
"info_from_input": "additional_info_2",
"model_name": "test",
- "distilabel_metadata": {"raw_output_task": "output"},
+ "distilabel_metadata": {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {"content": "", "role": "system"},
+ {"content": "test_2", "role": "user"},
+ ],
+ },
},
{
"instruction": "test_2",
@@ -163,7 +201,13 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"output": "output",
"info_from_input": "additional_info_2",
"model_name": "test",
- "distilabel_metadata": {"raw_output_task": "output"},
+ "distilabel_metadata": {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {"content": "", "role": "system"},
+ {"content": "test_2", "role": "user"},
+ ],
+ },
},
{
"instruction": "test_2",
@@ -171,7 +215,13 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"output": "output",
"info_from_input": "additional_info_2",
"model_name": "test",
- "distilabel_metadata": {"raw_output_task": "output"},
+ "distilabel_metadata": {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {"content": "", "role": "system"},
+ {"content": "test_2", "role": "user"},
+ ],
+ },
},
],
),
@@ -194,9 +244,48 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
],
"model_name": "test",
"distilabel_metadata": [
- {"raw_output_task": "output"},
- {"raw_output_task": "output"},
- {"raw_output_task": "output"},
+ {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "test_0",
+ "role": "user",
+ },
+ ],
+ },
+ {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "test_0",
+ "role": "user",
+ },
+ ],
+ },
+ {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "test_0",
+ "role": "user",
+ },
+ ],
+ },
+ # {"raw_output_task": "output"},
+ # {"raw_output_task": "output"},
+ # {"raw_output_task": "output"},
],
},
{
@@ -210,9 +299,45 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
],
"model_name": "test",
"distilabel_metadata": [
- {"raw_output_task": "output"},
- {"raw_output_task": "output"},
- {"raw_output_task": "output"},
+ {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "test_1",
+ "role": "user",
+ },
+ ],
+ },
+ {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "test_1",
+ "role": "user",
+ },
+ ],
+ },
+ {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "test_1",
+ "role": "user",
+ },
+ ],
+ },
],
},
{
@@ -226,9 +351,45 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
],
"model_name": "test",
"distilabel_metadata": [
- {"raw_output_task": "output"},
- {"raw_output_task": "output"},
- {"raw_output_task": "output"},
+ {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "test_2",
+ "role": "user",
+ },
+ ],
+ },
+ {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "test_2",
+ "role": "user",
+ },
+ ],
+ },
+ {
+ "raw_output_task": "output",
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "test_2",
+ "role": "user",
+ },
+ ],
+ },
],
},
],
@@ -242,7 +403,7 @@ def test_process(
expected: List[Dict[str, Any]],
) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
- llm = DummyLLM()
+ llm = DummyAsyncLLM()
task = DummyTask(
name="task",
llm=llm,
@@ -253,6 +414,94 @@ def test_process(
result = next(task.process(input))
assert result == expected
+ def test_process_overriding_inputs(self) -> None:
+ llm = DummyAsyncLLM()
+ task = DummyTask(
+ name="task",
+ llm=llm,
+ group_generations=False,
+ num_generations=3,
+ input_mappings={"instruction": "instruction_2"},
+ )
+
+ result = next(
+ task.process_applying_mappings(
+ [
+ {
+ "instruction": "instruction that won't be used but overriden by input mapping",
+ "instruction_2": "instruction that will be used as input",
+ "additional_info": "info",
+ }
+ ]
+ )
+ )
+
+ assert result == [
+ {
+ "additional_info": "info",
+ "distilabel_metadata": {
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "instruction that will be used as input",
+ "role": "user",
+ },
+ ],
+ "raw_output_task": "output",
+ },
+ "info_from_input": "info",
+ "instruction": "instruction that won't be used but overriden by input mapping",
+ "instruction_2": "instruction that will be used as input",
+ "model_name": "test",
+ "output": "output",
+ },
+ {
+ "additional_info": "info",
+ "distilabel_metadata": {
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "instruction that will be used as input",
+ "role": "user",
+ },
+ ],
+ "raw_output_task": "output",
+ },
+ "info_from_input": "info",
+ "instruction": "instruction that won't be used but overriden by input mapping",
+ "instruction_2": "instruction that will be used as input",
+ "model_name": "test",
+ "output": "output",
+ },
+ {
+ "additional_info": "info",
+ "distilabel_metadata": {
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "instruction that will be used as input",
+ "role": "user",
+ },
+ ],
+ "raw_output_task": "output",
+ },
+ "info_from_input": "info",
+ "instruction": "instruction that won't be used but overriden by input mapping",
+ "instruction_2": "instruction that will be used as input",
+ "model_name": "test",
+ "output": "output",
+ },
+ ]
+
def test_process_with_runtime_parameters(self) -> None:
# 1. Runtime parameters provided
llm = DummyRuntimeLLM() # type: ignore
@@ -268,6 +517,8 @@ def test_process_with_runtime_parameters(self) -> None:
"runtime_parameter": False,
"runtime_parameter_optional": True,
"generation_kwargs": {},
+ "offline_batch_generation_block_until_done": True,
+ "use_offline_batch_generation": True,
}
# 2. Runtime parameters in init
@@ -283,6 +534,8 @@ def test_process_with_runtime_parameters(self) -> None:
"runtime_parameter": False,
"runtime_parameter_optional": True,
"generation_kwargs": {},
+ "offline_batch_generation_block_until_done": True,
+ "use_offline_batch_generation": True,
}
# 3. Runtime parameters in init superseded by runtime parameters
@@ -299,15 +552,18 @@ def test_process_with_runtime_parameters(self) -> None:
"runtime_parameter": False,
"runtime_parameter_optional": True,
"generation_kwargs": {},
+ "offline_batch_generation_block_until_done": True,
+ "use_offline_batch_generation": True,
}
def test_serialization(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
- llm = DummyLLM()
+ llm = DummyAsyncLLM()
task = DummyTask(name="task", llm=llm, pipeline=pipeline)
assert task.dump() == {
"name": "task",
"add_raw_output": True,
+ "add_raw_input": True,
"input_mappings": {},
"output_mappings": {},
"resources": {
@@ -320,9 +576,13 @@ def test_serialization(self) -> None:
"input_batch_size": 50,
"llm": {
"generation_kwargs": {},
+ "structured_output": None,
+ "jobs_ids": None,
+ "offline_batch_generation_block_until_done": None,
+ "use_offline_batch_generation": False,
"type_info": {
"module": "tests.unit.conftest",
- "name": "DummyLLM",
+ "name": "DummyAsyncLLM",
},
},
"group_generations": False,
@@ -372,6 +632,20 @@ def test_serialization(self) -> None:
"keys": [],
"name": "generation_kwargs",
},
+ {
+ "description": "Whether to use the `offline_batch_generate` method to "
+ "generate the responses.",
+ "name": "use_offline_batch_generation",
+ "optional": True,
+ },
+ {
+ "description": "If provided, then polling will be done until the "
+ "`ofline_batch_generate` method is able to retrieve the "
+ "results. The value indicate the time to wait between each "
+ "polling.",
+ "name": "offline_batch_generation_block_until_done",
+ "optional": True,
+ },
],
},
{
@@ -379,18 +653,68 @@ def test_serialization(self) -> None:
"name": "add_raw_output",
"optional": True,
},
+ {
+ "description": "Whether to include the raw input of the LLM in the key `raw_input_` of the `distilabel_metadata` dictionary column",
+ "name": "add_raw_input",
+ "optional": True,
+ },
{
"name": "num_generations",
"description": "The number of generations to be produced per input.",
"optional": True,
},
],
+ "use_cache": True,
"type_info": {
- "module": "tests.unit.steps.tasks.test_base",
+ "module": "tests.unit.conftest",
"name": "DummyTask",
},
+ "use_default_structured_output": False,
}
with Pipeline(name="unit-test-pipeline") as pipeline:
new_task = DummyTask.from_dict(task.dump())
assert isinstance(new_task, DummyTask)
+
+ @pytest.mark.parametrize(
+ "add_raw_output, add_raw_input",
+ [
+ (True, False),
+ (False, True),
+ (True, True),
+ (False, False),
+ ],
+ )
+ def test_add_raw_input_and_or_output(
+ self, add_raw_output: bool, add_raw_input: bool
+ ) -> None:
+ task = DummyTask(
+ llm=DummyAsyncLLM(),
+ add_raw_output=add_raw_output,
+ add_raw_input=add_raw_input,
+ )
+ assert task.add_raw_output is add_raw_output
+ assert task.add_raw_input is add_raw_input
+ task.load()
+ input = [
+ {"instruction": "test_0", "additional_info": "additional_info_0"},
+ {"instruction": "test_1", "additional_info": "additional_info_1"},
+ {"instruction": "test_2", "additional_info": "additional_info_2"},
+ ]
+ result = next(task.process(input))
+ import pprint
+
+ pprint.pprint(result)
+
+ if add_raw_output or add_raw_input:
+ assert "distilabel_metadata" in result[0].keys()
+ if add_raw_output:
+ assert (
+ "raw_output_dummy_task_0" in result[0]["distilabel_metadata"].keys()
+ )
+ if add_raw_input:
+ assert (
+ "raw_input_dummy_task_0" in result[0]["distilabel_metadata"].keys()
+ )
+ else:
+ assert "distilabel_metadata" not in result[0].keys()
diff --git a/tests/unit/steps/tasks/test_clair.py b/tests/unit/steps/tasks/test_clair.py
new file mode 100644
index 0000000000..3d16c0bf48
--- /dev/null
+++ b/tests/unit/steps/tasks/test_clair.py
@@ -0,0 +1,74 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Union
+
+import pytest
+
+from distilabel.steps.tasks.clair import CLAIR
+from tests.unit.conftest import DummyLLM
+
+
+class TestCLAIR:
+ def test_format_input(self) -> None:
+ task = CLAIR(llm=DummyLLM())
+ task.load()
+
+ result = task.format_input(
+ input={"task": "TASK", "student_solution": "SOLUTION"}
+ )
+ # System prompt
+ assert (
+ result[0]["content"]
+ == "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution."
+ )
+ # User prompt
+ assert (
+ result[1]["content"]
+ == """\
+{task}: TASK
+
+{student_solution}: SOLUTION
+
+-----------------
+
+Let's first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer.
+""".strip()
+ )
+
+ @pytest.mark.parametrize(
+ "output, expected",
+ [
+ (None, {"revision": None, "rational": None}),
+ ("WRONG", {"revision": None, "rational": None}),
+ (
+ "{teacher_reasoning}\n\nreasoning\n\n{corrected_student_solution}\n\ncorrected",
+ {"revision": "corrected", "rational": "reasoning"},
+ ),
+ ],
+ )
+ def test_format_output(
+ self,
+ output: Union[str, None],
+ expected: Dict[str, Any],
+ ) -> None:
+ task = CLAIR(llm=DummyLLM())
+ task.load()
+
+ result = task.format_output(
+ output=output,
+ input={},
+ )
+
+ assert result == expected
diff --git a/tests/unit/steps/tasks/test_complexity_scorer.py b/tests/unit/steps/tasks/test_complexity_scorer.py
index 87b5bf33ac..308fe989be 100644
--- a/tests/unit/steps/tasks/test_complexity_scorer.py
+++ b/tests/unit/steps/tasks/test_complexity_scorer.py
@@ -18,14 +18,14 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.complexity_scorer import ComplexityScorer
-from tests.unit.conftest import DummyLLM
+from tests.unit.conftest import DummyAsyncLLM
class TestComplexityScorer:
def test_format_input(self) -> None:
task = ComplexityScorer(
name="complexity_scorer",
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
task.load()
@@ -42,29 +42,44 @@ def test_format_input(self) -> None:
]
@pytest.mark.parametrize(
- "output, expected",
+ "output, use_default_structured_output, expected",
[
(
"[1] Score: 1\n[2] Score: 2\n[3] Score: 3\n",
+ False,
{"scores": [1.0, 2.0, 3.0]},
),
(
"[1] Score: 1\n[2] Score: 2\n[3] Score: 3\njfjfjfjjfjfjf this is noise from the llm\nlallalalala more noise\nand more noise",
+ False,
{"scores": [1.0, 2.0, 3.0]},
),
(
None,
+ False,
+ {"scores": [None, None, None]},
+ ),
+ (
+ '{"scores":[1,2,3]}',
+ True,
+ {"scores": [1.0, 2.0, 3.0]},
+ ),
+ (
+ "wrong",
+ True,
{"scores": [None, None, None]},
),
],
)
def test_format_output(
- self, output: Union[str, None], expected: Dict[str, Any]
+ self,
+ output: Union[str, None],
+ use_default_structured_output: bool,
+ expected: Dict[str, Any],
) -> None:
task = ComplexityScorer(
- name="complexity_scorer",
- llm=DummyLLM(),
- pipeline=Pipeline(name="unit-test-pipeline"),
+ llm=DummyAsyncLLM(),
+ use_default_structured_output=use_default_structured_output,
)
task.load()
diff --git a/tests/unit/steps/tasks/test_genstruct.py b/tests/unit/steps/tasks/test_genstruct.py
index 88ed5f30ec..300237c843 100644
--- a/tests/unit/steps/tasks/test_genstruct.py
+++ b/tests/unit/steps/tasks/test_genstruct.py
@@ -18,14 +18,14 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.genstruct import Genstruct
-from tests.unit.conftest import DummyLLM
+from tests.unit.conftest import DummyAsyncLLM
class TestGenstruct:
def test_format_input(self) -> None:
task = Genstruct(
name="genstruct",
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
task.load()
@@ -62,7 +62,7 @@ def test_format_output(
) -> None:
task = Genstruct(
name="genstruct",
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
task.load()
diff --git a/tests/unit/steps/tasks/test_improving_text_embeddings.py b/tests/unit/steps/tasks/test_improving_text_embeddings.py
index 0fc04c01d2..dfaa247b91 100644
--- a/tests/unit/steps/tasks/test_improving_text_embeddings.py
+++ b/tests/unit/steps/tasks/test_improving_text_embeddings.py
@@ -67,6 +67,7 @@ def test_process(self, category: str, flatten_tasks: bool) -> None:
add_raw_output=False,
llm=MockLLM(output="[ 'A', 'B', 'C' ]"),
pipeline=Pipeline(name="unit-test-pipeline"),
+ add_raw_input=False,
)
task.load()
@@ -123,6 +124,7 @@ def test_process(self) -> None:
add_raw_output=False,
llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})),
pipeline=Pipeline(name="unit-test-pipeline"),
+ add_raw_input=False,
)
task.load()
@@ -185,6 +187,7 @@ def test_process(self) -> None:
add_raw_output=False,
llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})),
pipeline=Pipeline(name="unit-test-pipeline"),
+ add_raw_input=False,
)
task.load()
assert task.outputs == ["S1", "S2", "S3", "model_name"]
@@ -231,6 +234,7 @@ def test_process(self) -> None:
add_raw_output=False,
llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})),
pipeline=Pipeline(name="unit-test-pipeline"),
+ add_raw_input=False,
)
task.load()
@@ -262,6 +266,7 @@ def test_process(self) -> None:
add_raw_output=False,
llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})),
pipeline=Pipeline(name="unit-test-pipeline"),
+ add_raw_input=False,
)
task.load()
assert task.outputs == ["input", "positive_document", "model_name"]
@@ -317,6 +322,7 @@ def test_process(self) -> None:
)
),
pipeline=Pipeline(name="unit-test-pipeline"),
+ add_raw_input=False,
)
task.load()
assert task.outputs == ["input_text", "label", "misleading_label", "model_name"]
@@ -388,6 +394,7 @@ def test_process(self) -> None:
)
),
pipeline=Pipeline(name="unit-test-pipeline"),
+ add_raw_input=False,
)
task.load()
assert task.outputs == [
diff --git a/tests/unit/steps/tasks/test_instruction_backtranslation.py b/tests/unit/steps/tasks/test_instruction_backtranslation.py
index a6f2793285..1b2f9adffa 100644
--- a/tests/unit/steps/tasks/test_instruction_backtranslation.py
+++ b/tests/unit/steps/tasks/test_instruction_backtranslation.py
@@ -76,6 +76,7 @@ def test_process(self) -> None:
name="instruction-backtranslation",
llm=InstructionBacktranslationLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
+ add_raw_input=False,
)
task.load()
diff --git a/tests/unit/steps/tasks/test_pair_rm.py b/tests/unit/steps/tasks/test_pair_rm.py
index 22d8d3c454..104726307d 100644
--- a/tests/unit/steps/tasks/test_pair_rm.py
+++ b/tests/unit/steps/tasks/test_pair_rm.py
@@ -15,11 +15,13 @@
from unittest.mock import MagicMock, patch
import numpy as np
+import pytest
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.pair_rm import PairRM
+@pytest.mark.skip(reason="Not maintained and to be deprecated.")
@patch("llm_blender.Blender")
class TestPairRM:
def test_process(self, mocker: MagicMock) -> None:
@@ -109,5 +111,6 @@ def test_serialization(self, _: MagicMock) -> None:
"optional": True,
},
],
+ "use_cache": True,
"type_info": {"module": "distilabel.steps.tasks.pair_rm", "name": "PairRM"},
}
diff --git a/tests/unit/steps/tasks/test_prometheus_eval.py b/tests/unit/steps/tasks/test_prometheus_eval.py
index b7cf1cd55e..1781ac9e6a 100644
--- a/tests/unit/steps/tasks/test_prometheus_eval.py
+++ b/tests/unit/steps/tasks/test_prometheus_eval.py
@@ -27,7 +27,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.prometheus_eval import _DEFAULT_RUBRICS, PrometheusEval
-from tests.unit.conftest import DummyLLM
+from tests.unit.conftest import DummyAsyncLLM
def load_template(template: str) -> Template:
@@ -131,7 +131,7 @@ def test_format_input(
mode=mode, # type: ignore
rubric=rubric, # type: ignore
reference=reference,
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
task.load()
@@ -150,7 +150,7 @@ def test_format_input_errors(self) -> None:
mode="absolute",
rubric="helpfulness",
reference=True,
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
task.load()
@@ -164,7 +164,7 @@ def test_format_input_errors(self) -> None:
mode="absolute",
rubric="helpfulness",
reference=False,
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
task.load()
@@ -181,7 +181,7 @@ def test_format_input_errors(self) -> None:
mode="relative",
rubric="helpfulness",
reference=False,
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
task.load()
@@ -258,7 +258,7 @@ def test_format_output(
mode=mode, # type: ignore
rubric="factual-validity",
reference=False,
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
task.load()
@@ -278,7 +278,7 @@ def test_custom_rubrics(self) -> None:
"custom": "[A]\nScore 1: A\nScore 2: B\nScore 3: C\nScore 4: D\nScore 5: E"
},
reference=False,
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
@@ -294,7 +294,7 @@ def test_custom_rubrics_errors(self) -> None:
rubric="custom",
rubrics={},
reference=False,
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
with pytest.raises(
@@ -307,7 +307,7 @@ def test_custom_rubrics_errors(self) -> None:
rubric="custom",
rubrics={"custom": 1},
reference=False,
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
# 2. `rubrics` is not compliant with the pre-defined schema
@@ -321,7 +321,7 @@ def test_custom_rubrics_errors(self) -> None:
rubric="custom",
rubrics={"custom": "wrong schema"},
reference=False,
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
# 3. `rubric` is not available in `rubrics`
@@ -337,6 +337,6 @@ def test_custom_rubrics_errors(self) -> None:
"custom": "[A]\nScore 1: A\nScore 2: B\nScore 3: C\nScore 4: D\nScore 5: E"
},
reference=False,
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
diff --git a/tests/unit/steps/tasks/test_quality_scorer.py b/tests/unit/steps/tasks/test_quality_scorer.py
index 554ede1a7e..3929aaaedf 100644
--- a/tests/unit/steps/tasks/test_quality_scorer.py
+++ b/tests/unit/steps/tasks/test_quality_scorer.py
@@ -18,14 +18,14 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.quality_scorer import QualityScorer
-from tests.unit.conftest import DummyLLM
+from tests.unit.conftest import DummyAsyncLLM
class TestQualityScorer:
def test_format_input(self) -> None:
task = QualityScorer(
name="quality_scorer",
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
task.load()
@@ -45,29 +45,44 @@ def test_format_input(self) -> None:
]
@pytest.mark.parametrize(
- "output, expected",
+ "output, use_default_structured_output, expected",
[
(
"[1] Score: 1\n[2] Score: 2\n[3] Score: 3\n",
+ False,
{"scores": [1.0, 2.0, 3.0]},
),
(
"[1] Score: 1\n[2] Score: 2\n[3] Score: 3\njfjfjfjjfjfjf this is noise from the llm\nlallalalala more noise\nand more noise",
+ False,
{"scores": [1.0, 2.0, 3.0]},
),
(
None,
+ False,
+ {"scores": [None, None, None]},
+ ),
+ (
+ '{"scores":[1,2,3]}',
+ True,
+ {"scores": [1.0, 2.0, 3.0]},
+ ),
+ (
+ "wrong",
+ True,
{"scores": [None, None, None]},
),
],
)
def test_format_output(
- self, output: Union[str, None], expected: Dict[str, Any]
+ self,
+ output: Union[str, None],
+ use_default_structured_output: bool,
+ expected: Dict[str, Any],
) -> None:
task = QualityScorer(
- name="quality_score",
- llm=DummyLLM(),
- pipeline=Pipeline(name="unit-test-pipeline"),
+ llm=DummyAsyncLLM(),
+ use_default_structured_output=use_default_structured_output,
)
task.load()
diff --git a/tests/unit/steps/tasks/test_self_instruct.py b/tests/unit/steps/tasks/test_self_instruct.py
index 1f539d0e3b..76a24497e2 100644
--- a/tests/unit/steps/tasks/test_self_instruct.py
+++ b/tests/unit/steps/tasks/test_self_instruct.py
@@ -14,14 +14,14 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.self_instruct import SelfInstruct
-from tests.unit.conftest import DummyLLM
+from tests.unit.conftest import DummyAsyncLLM
class TestSelfInstruct:
def test_format_input(self) -> None:
task = SelfInstruct(
name="self_instruct",
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
task.load()
@@ -37,7 +37,7 @@ def test_format_input(self) -> None:
def test_format_output(self) -> None:
task = SelfInstruct(
name="self_instruct",
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
pipeline=Pipeline(name="unit-test-pipeline"),
)
task.load()
diff --git a/tests/unit/steps/tasks/test_sentence_transformers.py b/tests/unit/steps/tasks/test_sentence_transformers.py
index 146bd2bbdd..9dc6b38ae1 100644
--- a/tests/unit/steps/tasks/test_sentence_transformers.py
+++ b/tests/unit/steps/tasks/test_sentence_transformers.py
@@ -24,7 +24,28 @@
GenerateSentencePair,
GenerationAction,
)
-from tests.unit.conftest import DummyLLM
+from tests.unit.conftest import DummyAsyncLLM
+
+# from distilabel.llms.base import LLM, AsyncLLM
+
+# if TYPE_CHECKING:
+# from distilabel.llms.typing import GenerateOutput
+# from distilabel.steps.tasks.typing import FormattedInput
+
+# # Defined here too, so that the serde still works
+# class DummyStructuredLLM(LLM):
+# structured_output: Any = None
+# def load(self) -> None:
+# pass
+
+# @property
+# def model_name(self) -> str:
+# return "test"
+
+# def generate(
+# self, input: "FormattedInput", num_generations: int = 1
+# ) -> "GenerateOutput":
+# return ['{ \n "negative": "negative",\n "positive": "positive"\n}' for _ in range(num_generations)]
class TestGenerateSentencePair:
@@ -151,7 +172,10 @@ def test_format_input(
system_prompt: str,
) -> None:
task = GenerateSentencePair(
- llm=DummyLLM(), action=action, triplet=triplet, hard_negative=hard_negative
+ llm=DummyAsyncLLM(),
+ action=action,
+ triplet=triplet,
+ hard_negative=hard_negative,
)
task.load()
content = "## Anchor\n\nThis is a unit test\n"
@@ -286,7 +310,7 @@ def test_format_input_with_context(
) -> None:
context = "This is your context."
task = GenerateSentencePair(
- llm=DummyLLM(),
+ llm=DummyAsyncLLM(),
action=action,
triplet=triplet,
context=context,
@@ -300,11 +324,12 @@ def test_format_input_with_context(
]
@pytest.mark.parametrize(
- "output,triplet,expected",
+ "output,triplet,use_default_structured_output,expected",
[
(
"## Positive\n\nThis is a paraphrase\n## Negative\n\nThis is not a paraphrase",
True,
+ False,
{
"positive": "This is a paraphrase",
"negative": "This is not a paraphrase",
@@ -313,25 +338,66 @@ def test_format_input_with_context(
(
"## Positive\n\nThis is a paraphrase",
True,
+ False,
{"positive": "This is a paraphrase", "negative": None},
),
(
"## Positive\n\nThis is a paraphrase",
False,
+ False,
{"positive": "This is a paraphrase"},
),
(
"random",
False,
+ False,
{"positive": None},
),
+ (
+ '{ \n "negative": "This is not a paraphrase",\n "positive": "This is a paraphrase"\n}',
+ True,
+ True,
+ {
+ "positive": "This is a paraphrase",
+ "negative": "This is not a paraphrase",
+ },
+ ),
+ (
+ '{ \n "positive": "This is a paraphrase"\n}',
+ True,
+ True,
+ {
+ "positive": "This is a paraphrase",
+ },
+ ),
+ (
+ "{ \n random\n}",
+ False,
+ True,
+ {
+ "positive": None,
+ },
+ ),
+ (
+ "{ \n random\n}",
+ True,
+ True,
+ {"positive": None, "negative": None},
+ ),
],
)
def test_format_output(
- self, output: str, triplet: bool, expected: Dict[str, Any]
+ self,
+ output: str,
+ triplet: bool,
+ use_default_structured_output: bool,
+ expected: Dict[str, Any],
) -> None:
task = GenerateSentencePair(
- llm=DummyLLM(), action="paraphrase", triplet=triplet
+ llm=DummyAsyncLLM(),
+ action="paraphrase",
+ triplet=triplet,
+ use_default_structured_output=use_default_structured_output,
)
task.load()
diff --git a/tests/unit/steps/tasks/test_structured_generation.py b/tests/unit/steps/tasks/test_structured_generation.py
index 8fac5b58fb..a57d0da7df 100644
--- a/tests/unit/steps/tasks/test_structured_generation.py
+++ b/tests/unit/steps/tasks/test_structured_generation.py
@@ -87,7 +87,9 @@ def test_format_input_with_system_prompt(self) -> None:
def test_process(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
llm = DummyStructuredLLM()
- task = StructuredGeneration(name="task", llm=llm, pipeline=pipeline)
+ task = StructuredGeneration(
+ name="task", llm=llm, pipeline=pipeline, add_raw_input=False
+ )
assert next(
task.process(
[
diff --git a/tests/unit/steps/tasks/test_text_classification.py b/tests/unit/steps/tasks/test_text_classification.py
new file mode 100644
index 0000000000..e5af171b33
--- /dev/null
+++ b/tests/unit/steps/tasks/test_text_classification.py
@@ -0,0 +1,140 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+import pytest
+
+from distilabel.steps.tasks.text_classification import TextClassification
+from tests.unit.conftest import DummyAsyncLLM
+
+if TYPE_CHECKING:
+ from distilabel.llms.typing import GenerateOutput
+ from distilabel.steps.tasks.typing import FormattedInput
+
+
+class TextClassificationLLM(DummyAsyncLLM):
+ n: int = 1
+
+ async def agenerate( # type: ignore
+ self, input: "FormattedInput", num_generations: int = 1
+ ) -> "GenerateOutput":
+ if self.n == 1:
+ return [json.dumps({"labels": "label"}) for _ in range(num_generations)]
+ return [
+ json.dumps({"labels": [f"label_{i}" for i in range(self.n)]})
+ for _ in range(num_generations)
+ ]
+
+
+class TestTextClassification:
+ @pytest.mark.parametrize(
+ "n, context, examples, available_labels, default_label, query_title",
+ [
+ (1, "context", None, None, "Unclassified", "User Query"),
+ (1, "", ["example"], ["label1", "label2"], "default", "User Query"),
+ (
+ 1,
+ "",
+ ["example"],
+ {"label1": "explanation 1", "label2": "explanation 2"},
+ "default",
+ "User Query",
+ ),
+ (
+ 3,
+ "",
+ ["example", "other example"],
+ None,
+ "default",
+ "User Query",
+ ),
+ ],
+ )
+ def test_format_input(
+ self,
+ n: int,
+ context: str,
+ examples: Optional[List[str]],
+ available_labels: Optional[Union[List[str], Dict[str, str]]],
+ default_label: Optional[Union[str, List[str]]],
+ query_title: str,
+ ) -> None:
+ task = TextClassification(
+ llm=DummyAsyncLLM(),
+ n=n,
+ context=context,
+ examples=examples,
+ available_labels=available_labels,
+ default_label=default_label,
+ query_title=query_title,
+ )
+ task.load()
+
+ result = task.format_input({"text": "SAMPLE_TEXT"})
+ content = result[1]["content"]
+
+ assert f'respond with "{default_label}"' in content
+ assert "## User Query\n```\nSAMPLE_TEXT\n```" in content
+ assert f'respond with "{default_label}"' in content
+ if n == 1:
+ assert "Provide the label that best describes the text." in content
+ assert '```\n{\n "labels": "label"\n}\n```' in content
+ else:
+ assert (
+ f"Provide a list of {n} labels that best describe the text." in content
+ )
+ assert (
+ '```\n{\n "labels": ["label_0", "label_1", "label_2"]\n}\n```'
+ in content
+ )
+ if available_labels:
+ if isinstance(available_labels, list):
+ assert 'Use the available labels to classify the user query:\navailable_labels = [\n "label1",\n "label2"\n]'
+ if isinstance(available_labels, dict):
+ assert 'Use the available labels to classify the user query:\navailable_labels = [\n "label1", # explanation 1\n "label2", # explanation 2\n]'
+
+ if examples:
+ assert (
+ "## Examples\nHere are some examples to help you understand the task:\n- example\n"
+ in content
+ )
+ else:
+ assert "## Examples" not in content
+ assert (
+ f"Please classify the {query_title.lower()} by assigning the most appropriate labels."
+ in content
+ )
+ assert f"## {query_title}" in content
+
+ @pytest.mark.parametrize(
+ "n, expected",
+ [
+ (1, json.dumps({"labels": "label"})),
+ (3, json.dumps({"labels": ["label_0", "label_1", "label_2"]})),
+ ],
+ )
+ def test_process(self, n: int, expected: str) -> None:
+ task = TextClassification(
+ llm=TextClassificationLLM(n=n), n=n, use_default_structured_output=True
+ )
+ task.load()
+ result = next(task.process([{"text": "SAMPLE_TEXT"}]))
+ assert result[0]["text"] == "SAMPLE_TEXT"
+ assert result[0]["labels"] == json.loads(expected)["labels"]
+ assert (
+ result[0]["distilabel_metadata"]["raw_output_text_classification_0"]
+ == expected
+ )
diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py
index 741830c9d9..2a6abefb22 100644
--- a/tests/unit/steps/tasks/test_text_generation.py
+++ b/tests/unit/steps/tasks/test_text_generation.py
@@ -12,34 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, List, Union
+
import pytest
+from distilabel.errors import DistilabelUserError
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
-from tests.unit.conftest import DummyLLM
+from tests.unit.conftest import DummyAsyncLLM
class TestTextGeneration:
def test_format_input(self) -> None:
+ llm = DummyAsyncLLM()
+ task = TextGeneration(name="task", llm=llm)
+ task.load()
+
+ assert task.format_input({"instruction": "test"}) == [
+ {"role": "user", "content": "test"}
+ ]
+
+ def test_format_input_with_system_prompt(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
- llm = DummyLLM()
+ llm = DummyAsyncLLM()
task = TextGeneration(
- name="task", llm=llm, pipeline=pipeline, use_system_prompt=False
+ name="task", llm=llm, pipeline=pipeline, system_prompt="test"
)
+ task.load()
+
+ assert task.format_input({"instruction": "test"}) == [
+ {"role": "system", "content": "test"},
+ {"role": "user", "content": "test"},
+ ]
+
+ def test_format_input_with_row_system_prompt(self) -> None:
+ pipeline = Pipeline(name="unit-test-pipeline")
+ llm = DummyAsyncLLM()
+ task = TextGeneration(name="task", llm=llm, pipeline=pipeline)
+ task.load()
assert task.format_input({"instruction": "test", "system_prompt": "test"}) == [
- {"role": "user", "content": "test"}
+ {"role": "system", "content": "test"},
+ {"role": "user", "content": "test"},
]
- def test_format_input_with_system_prompt(self) -> None:
+ def test_format_input_with_row_system_prompt_and_system_prompt(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
- llm = DummyLLM()
+ llm = DummyAsyncLLM()
task = TextGeneration(
- name="task",
- llm=llm,
- pipeline=pipeline,
- use_system_prompt=True,
+ name="task", llm=llm, pipeline=pipeline, system_prompt="i won't be used"
)
+ task.load()
assert task.format_input({"instruction": "test", "system_prompt": "test"}) == [
{"role": "system", "content": "test"},
@@ -48,10 +71,11 @@ def test_format_input_with_system_prompt(self) -> None:
def test_format_input_errors(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
- llm = DummyLLM()
+ llm = DummyAsyncLLM()
task = TextGeneration(
name="task", llm=llm, pipeline=pipeline, use_system_prompt=True
)
+ task.load()
with pytest.raises(
ValueError,
@@ -64,18 +88,13 @@ def test_format_input_errors(self) -> None:
):
task.format_input({"instruction": 1})
- with pytest.warns(
- UserWarning,
- match=r"\`use_system_prompt\` is set to \`True\`, but no \`system_prompt\` in input batch, so it will be ignored.",
- ):
- assert task.format_input({"instruction": "test"}) == [
- {"role": "user", "content": "test"}
- ]
-
def test_process(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
- llm = DummyLLM()
- task = TextGeneration(name="task", llm=llm, pipeline=pipeline)
+ llm = DummyAsyncLLM()
+ task = TextGeneration(
+ name="task", llm=llm, pipeline=pipeline, add_raw_input=False
+ )
+ task.load()
assert next(task.process([{"instruction": "test"}])) == [
{
@@ -88,11 +107,78 @@ def test_process(self) -> None:
}
]
+ @pytest.mark.parametrize(
+ "template, columns, sample",
+ [
+ ("{{ instruction }}", "instruction", {"instruction": "INSTRUCTION"}),
+ (
+ "Document:\n{{ document }}\n\nQuestion: {{ question }}\n\nPlease provide a clear and concise answer to the question based on the information in the document and your general knowledge:",
+ ["document", "question"],
+ {"document": "DOCUMENT", "question": "QUESTION"},
+ ),
+ (
+ "Generate a clear, single-sentence instruction based on the following examples:\n\n{% for example in examples %}\nExample {{ loop.index }}:\nInstruction: {{ example }}\n\n{% endfor %}\nNow, generate a new instruction in a similar style:\n",
+ "examples",
+ {"examples": ["example1", "example2"]},
+ ),
+ ],
+ )
+ def test_format_input_custom_columns(
+ self,
+ template: str,
+ columns: Union[str, List[str]],
+ sample: Dict[str, Any],
+ ) -> None:
+ task = TextGeneration(
+ llm=DummyAsyncLLM(),
+ system_prompt=None,
+ template=template,
+ columns=columns,
+ add_raw_input=False,
+ add_raw_output=False,
+ )
+ task.load()
+
+ # Check the input from the sample are present in the formatted input
+ result = task.format_input(sample)[0]["content"]
+ values = list(sample.values())
+
+ if isinstance(values[0], list):
+ values = values[0]
+ assert all(v in result for v in values)
+
+ @pytest.mark.parametrize(
+ "template, columns, sample",
+ [
+ (
+ "This is a {{ custom }} template",
+ "instruction",
+ {"other": "INSTRUCTION"},
+ ),
+ ],
+ )
+ def test_format_input_custom_columns_expected_errors(
+ self,
+ template: str,
+ columns: Union[str, List[str]],
+ sample: Dict[str, Any],
+ ) -> None:
+ task = TextGeneration(
+ llm=DummyAsyncLLM(),
+ system_prompt=None,
+ template=template,
+ columns=columns,
+ add_raw_input=False,
+ add_raw_output=False,
+ )
+ with pytest.raises(DistilabelUserError):
+ task.load()
+
class TestChatGeneration:
def test_format_input(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
- llm = DummyLLM()
+ llm = DummyAsyncLLM()
task = ChatGeneration(name="task", llm=llm, pipeline=pipeline)
assert task.format_input(
@@ -109,7 +195,7 @@ def test_format_input(self) -> None:
def test_format_input_errors(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
- llm = DummyLLM()
+ llm = DummyAsyncLLM()
task = ChatGeneration(name="task", llm=llm, pipeline=pipeline)
with pytest.raises(ValueError, match="The last message must be from the user"):
@@ -124,8 +210,10 @@ def test_format_input_errors(self) -> None:
def test_process(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
- llm = DummyLLM()
- task = ChatGeneration(name="task", llm=llm, pipeline=pipeline)
+ llm = DummyAsyncLLM()
+ task = ChatGeneration(
+ name="task", llm=llm, pipeline=pipeline, add_raw_input=False
+ )
assert next(
task.process(
diff --git a/tests/unit/steps/tasks/test_ultrafeedback.py b/tests/unit/steps/tasks/test_ultrafeedback.py
index fa72ff9442..5565065d61 100644
--- a/tests/unit/steps/tasks/test_ultrafeedback.py
+++ b/tests/unit/steps/tasks/test_ultrafeedback.py
@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List
+from typing import Any, Dict, List, Union
+
+import pytest
from distilabel.llms.base import LLM
from distilabel.llms.typing import GenerateOutput
-from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.typing import ChatType
from distilabel.steps.tasks.ultrafeedback import UltraFeedback
class UltraFeedbackLLM(LLM):
+ structured_output: Any = None
+
def load(self) -> None:
pass
@@ -43,14 +46,12 @@ def generate(
class TestUltraFeedback:
def test_process_with_simple_aspect(self) -> None:
- pipeline = Pipeline(name="unit-test-pipeline")
- llm = UltraFeedbackLLM()
-
task = UltraFeedback(
name="ultrafeedback",
aspect="instruction-following",
- llm=llm,
- pipeline=pipeline,
+ llm=UltraFeedbackLLM(),
+ use_default_structured_output=False,
+ add_raw_input=False,
)
task.load()
@@ -70,14 +71,12 @@ def test_process_with_simple_aspect(self) -> None:
]
def test_process_with_complex_aspect(self) -> None:
- pipeline = Pipeline(name="unit-test-pipeline")
- llm = UltraFeedbackLLM()
-
task = UltraFeedback(
name="ultrafeedback",
aspect="truthfulness",
- llm=llm,
- pipeline=pipeline,
+ llm=UltraFeedbackLLM(),
+ use_default_structured_output=False,
+ add_raw_input=False,
)
task.load()
@@ -97,3 +96,66 @@ def test_process_with_complex_aspect(self) -> None:
},
}
]
+
+ @pytest.mark.parametrize(
+ "output, use_default_structured_output, aspect, expected",
+ [
+ (
+ "{ \n random\n}",
+ True,
+ "honesty",
+ {"ratings": [None, None], "rationales": [None, None]},
+ ),
+ (
+ '{ \n "ratings": [\n 1,\n 5\n ]\n ,\n "rationales": [\n "rationale1",\n "rationale2"\n ]}',
+ True,
+ "honesty",
+ {"ratings": [1, 5], "rationales": ["rationale1", "rationale2"]},
+ ),
+ (
+ "{ \n random\n}",
+ True,
+ "helpfulness",
+ {
+ "ratings": [None, None],
+ "rationales": [None, None],
+ "rationales-for-ratings": [None, None],
+ "types": [None, None],
+ },
+ ),
+ (
+ '{ \n "ratings": [\n 1,\n 5\n ]\n ,\n "rationales": [\n "rationale1",\n "rationale2"\n ], "rationales-for-ratings": [\n "rationale1",\n "rationale2"\n ], "types": [\n 1,\n 2\n ]}',
+ True,
+ "helpfulness",
+ {
+ "ratings": [1, 5],
+ "rationales": ["rationale1", "rationale2"],
+ "rationales-for-ratings": ["rationale1", "rationale2"],
+ "types": [1, 2],
+ },
+ ),
+ ],
+ )
+ def test_format_output(
+ self,
+ output: Union[str, None],
+ use_default_structured_output: bool,
+ aspect: str,
+ expected: Dict[str, Any],
+ ) -> None:
+ task = UltraFeedback(
+ llm=UltraFeedbackLLM(),
+ aspect=aspect,
+ use_default_structured_output=use_default_structured_output,
+ )
+ task.load()
+
+ result = task.format_output(
+ output=output,
+ input={
+ "instruction": "How much is 2+2?",
+ "generations": ["4", "something weird"],
+ },
+ )
+
+ assert result == expected
diff --git a/tests/unit/steps/tasks/test_urial.py b/tests/unit/steps/tasks/test_urial.py
new file mode 100644
index 0000000000..f31ac0e5e2
--- /dev/null
+++ b/tests/unit/steps/tasks/test_urial.py
@@ -0,0 +1,72 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from distilabel.steps.tasks.urial import URIAL
+from tests.unit.conftest import DummyAsyncLLM
+
+
+class TestURIAL:
+ def test_format_input(self) -> None:
+ task = URIAL(llm=DummyAsyncLLM())
+ task.load()
+ assert task.format_input({"instruction": "test"}) == [
+ {
+ "role": "user",
+ "content": '# Instruction\n\nBelow is a list of conversations between a human and an AI assistant (you). \nUsers place their queries under "# User:", and your responses are under "# Assistant:".\nYou are a helpful, respectful, and honest assistant.\nYou should always answer as helpfully as possible while ensuring safety.\nYour answers should be well-structured and provide detailed information. They should also have an engaging tone.\nYour responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.\nYour response must be socially responsible, and thus you can refuse to answer some controversial topics.\n\n\n# User:\n\ntest\n\n# Assistant:',
+ }
+ ]
+
+ def test_format_input_with_conversation(self) -> None:
+ task = URIAL(llm=DummyAsyncLLM())
+ task.load()
+ assert task.format_input(
+ {
+ "conversation": [
+ {"role": "user", "content": "test"},
+ {"role": "assistant", "content": "test"},
+ {"role": "user", "content": "test"},
+ ]
+ }
+ ) == [
+ {
+ "role": "user",
+ "content": '# Instruction\n\nBelow is a list of conversations between a human and an AI assistant (you). \nUsers place their queries under "# User:", and your responses are under "# Assistant:".\nYou are a helpful, respectful, and honest assistant.\nYou should always answer as helpfully as possible while ensuring safety.\nYour answers should be well-structured and provide detailed information. They should also have an engaging tone.\nYour responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.\nYour response must be socially responsible, and thus you can refuse to answer some controversial topics.\n\n\n# User:\n\ntest\n\n# Assistant:\n\ntest\n\n# User:\n\ntest\n\n# Assistant:',
+ }
+ ]
+
+ def test_format_input_raise_valueerror(self) -> None:
+ task = URIAL(llm=DummyAsyncLLM())
+ task.load()
+
+ with pytest.raises(ValueError, match="The last message must be from the user."):
+ assert task.format_input(
+ {
+ "conversation": [
+ {"role": "user", "content": "test"},
+ {"role": "assistant", "content": "test"},
+ ]
+ }
+ )
+
+ def test_format_output(self) -> None:
+ task = URIAL(llm=DummyAsyncLLM())
+ task.load()
+
+ assert task.format_output(
+ output=" \n\noutput\n\n# User:", input={"instruction": "test"}
+ ) == {
+ "generation": "output",
+ }
diff --git a/tests/unit/steps/test_base.py b/tests/unit/steps/test_base.py
index 5a997ada66..6e8297bb06 100644
--- a/tests/unit/steps/test_base.py
+++ b/tests/unit/steps/test_base.py
@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import tempfile
+from pathlib import Path
from typing import List, Optional
import pytest
from pydantic import ValidationError
+from distilabel.constants import ROUTING_BATCH_FUNCTION_ATTR_NAME
from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.pipeline.constants import ROUTING_BATCH_FUNCTION_ATTR_NAME
from distilabel.pipeline.local import Pipeline
from distilabel.steps.base import GeneratorStep, GlobalStep, Step, StepInput
from distilabel.steps.decorator import step
@@ -27,6 +29,8 @@
class DummyStep(Step):
+ attr1: int = 5
+
@property
def inputs(self) -> List[str]:
return ["instruction"]
@@ -64,6 +68,16 @@ def process(self, inputs: StepInput) -> StepOutput:
class TestStep:
+ def test_signature(self) -> None:
+ step = DummyStep(attr1=5)
+ assert step.signature == "a0ce83adedabec3fba270ec7bc8a52a62cbbee40"
+
+ step = DummyStep(attr1=5)
+ assert step.signature == "a0ce83adedabec3fba270ec7bc8a52a62cbbee40"
+
+ step = DummyStep(attr1=1234)
+ assert step.signature == "c00e67df4f7ed97a2bf8d9b1178d6c728e577c3b"
+
def test_create_step_with_invalid_name(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
@@ -161,7 +175,18 @@ def test_get_inputs(self) -> None:
pipeline=Pipeline(name="unit-test-pipeline"),
input_mappings={"instruction": "prompt"},
)
- assert step.get_inputs() == ["prompt"]
+ assert step.get_inputs() == {"prompt": True}
+
+ def test_get_inputs_with_dict(self) -> None:
+ @step(inputs={"instruction": False, "completion": True}, outputs=["score"])
+ def DummyStepWithDict(input: StepInput):
+ pass
+
+ dummy_step_with_dict = DummyStepWithDict()
+ assert dummy_step_with_dict.get_inputs() == {
+ "instruction": False,
+ "completion": True,
+ }
def test_get_outputs(self) -> None:
step = DummyStep(
@@ -169,7 +194,15 @@ def test_get_outputs(self) -> None:
pipeline=Pipeline(name="unit-test-pipeline"),
output_mappings={"response": "generation"},
)
- assert step.get_outputs() == ["generation"]
+ assert step.get_outputs() == {"generation": True}
+
+ def test_get_outputs_with_dict(self) -> None:
+ @step(outputs={"score": False})
+ def DummyStepWithDict(input: StepInput):
+ pass
+
+ dummy_step_with_dict = DummyStepWithDict()
+ assert dummy_step_with_dict.get_outputs() == {"score": False}
def test_apply_input_mappings(self) -> None:
step = DummyStep(
@@ -193,18 +226,21 @@ def test_apply_input_mappings(self) -> None:
)
)
- assert inputs == [
- [
- {"instruction": "hello 1"},
- {"instruction": "hello 2"},
- {"instruction": "hello 3"},
- ],
- [
- {"instruction": "bye 1"},
- {"instruction": "bye 2"},
- {"instruction": "bye 3"},
- ],
- ]
+ assert inputs == (
+ (
+ [
+ {"instruction": "hello 1"},
+ {"instruction": "hello 2"},
+ {"instruction": "hello 3"},
+ ],
+ [
+ {"instruction": "bye 1"},
+ {"instruction": "bye 2"},
+ {"instruction": "bye 3"},
+ ],
+ ),
+ [{}, {}, {}],
+ )
def test_process_applying_mappings(self) -> None:
step = DummyStep(
@@ -230,6 +266,42 @@ def test_process_applying_mappings(self) -> None:
{"prompt": "hello 3", "generation": "unit test"},
]
+ def test_process_applying_mappings_and_overriden_inputs(self) -> None:
+ step = DummyStep(
+ name="dummy",
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ input_mappings={"instruction": "prompt"},
+ output_mappings={"response": "generation"},
+ )
+
+ outputs = next(
+ step.process_applying_mappings(
+ [
+ {"prompt": "hello 1", "instruction": "overriden 1"},
+ {"prompt": "hello 2", "instruction": "overriden 2"},
+ {"prompt": "hello 3", "instruction": "overriden 3"},
+ ]
+ )
+ )
+
+ assert outputs == [
+ {
+ "prompt": "hello 1",
+ "generation": "unit test",
+ "instruction": "overriden 1",
+ },
+ {
+ "prompt": "hello 2",
+ "generation": "unit test",
+ "instruction": "overriden 2",
+ },
+ {
+ "prompt": "hello 3",
+ "generation": "unit test",
+ "instruction": "overriden 3",
+ },
+ ]
+
def test_connect(self) -> None:
@step(inputs=["instruction"], outputs=["generation"])
def GenerationStep(input: StepInput):
@@ -260,6 +332,48 @@ def routing_batch_function(downstream_step_names: List[str]) -> List[str]:
== routing_batch_function
)
+ def test_set_pipeline_artifacts_path(self) -> None:
+ step = DummyStep()
+ step.set_pipeline_artifacts_path(Path("/tmp"))
+ assert step.artifacts_directory == Path(f"/tmp/{step.name}")
+
+ def test_save_artifact(self) -> None:
+ with tempfile.TemporaryDirectory() as tempdir:
+ pipeline_artifacts_path = Path(tempdir)
+ step = DummyStep()
+ step.load()
+ step.set_pipeline_artifacts_path(pipeline_artifacts_path)
+ step.save_artifact(
+ name="unit-test",
+ write_function=lambda path: Path(path / "file.txt").write_text(
+ "unit test"
+ ),
+ metadata={"unit-test": True},
+ )
+
+ artifact_path = pipeline_artifacts_path / step.name / "unit-test" # type: ignore
+
+ assert artifact_path.is_dir()
+ assert (artifact_path / "file.txt").read_text() == "unit test"
+ assert (artifact_path / "metadata.json").read_text() == '{"unit-test":true}'
+
+ def test_save_artifact_without_setting_path(self) -> None:
+ with tempfile.TemporaryDirectory() as tempdir:
+ pipeline_artifacts_path = Path(tempdir)
+ step = DummyStep()
+ step.load()
+ step.save_artifact(
+ name="unit-test",
+ write_function=lambda path: Path(path / "file.txt").write_text(
+ "unit test"
+ ),
+ metadata={"unit-test": True},
+ )
+
+ artifact_path = pipeline_artifacts_path / step.name / "unit-test" # type: ignore
+
+ assert not artifact_path.exists()
+
class TestGeneratorStep:
def test_is_generator(self) -> None:
@@ -295,6 +409,7 @@ def test_step_dump(self) -> None:
step = DummyStep(name="dummy", pipeline=pipeline)
assert step.dump() == {
"name": "dummy",
+ "attr1": 5,
"input_batch_size": 50,
"input_mappings": {},
"output_mappings": {},
@@ -342,6 +457,7 @@ def test_step_dump(self) -> None:
"optional": True,
},
],
+ "use_cache": True,
TYPE_INFO_KEY: {
"module": "tests.unit.steps.test_base",
"name": "DummyStep",
diff --git a/tests/unit/steps/test_truncate.py b/tests/unit/steps/test_truncate.py
new file mode 100644
index 0000000000..52a07d6642
--- /dev/null
+++ b/tests/unit/steps/test_truncate.py
@@ -0,0 +1,47 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import pytest
+
+from distilabel.steps.truncate import TruncateTextColumn
+
+
+@pytest.mark.parametrize(
+ "max_length, text, tokenizer, expected",
+ [
+ (
+ 10,
+ "This is a sample text that is longer than 10 characters",
+ None,
+ "This is a ",
+ ),
+ (
+ 4,
+ "This is a sample text that is longer than 10 characters",
+ "teknium/OpenHermes-2.5-Mistral-7B",
+ "This is a sample",
+ ),
+ ],
+)
+def test_truncate_row(
+ max_length: int, text: str, tokenizer: Optional[str], expected: str
+) -> None:
+ trunc = TruncateTextColumn(
+ column="text", max_length=max_length, tokenizer=tokenizer
+ )
+ trunc.load()
+
+ assert next(trunc.process([{"text": text}])) == [{"text": expected}]
diff --git a/tests/unit/test_distiset.py b/tests/unit/test_distiset.py
index 3eee6c527a..1649a2ff18 100644
--- a/tests/unit/test_distiset.py
+++ b/tests/unit/test_distiset.py
@@ -24,10 +24,11 @@
from upath import UPath
from distilabel.distiset import Distiset
+from distilabel.utils.serialization import write_json
@pytest.fixture(scope="function")
-def distiset():
+def distiset() -> Distiset:
return Distiset(
{
"leaf_step_1": Dataset.from_dict({"a": [1, 2, 3]}),
@@ -43,14 +44,32 @@ def make_fake_file(filename: Path) -> None:
def add_config_to_distiset(distiset: Distiset, folder: Path) -> Distiset:
- from distilabel.distiset import DISTISET_CONFIG_FOLDER
+ from distilabel.constants import DISTISET_CONFIG_FOLDER
pipeline_yaml = folder / DISTISET_CONFIG_FOLDER / "pipeline.yaml"
pipeline_log = folder / DISTISET_CONFIG_FOLDER / "pipeline.log"
make_fake_file(pipeline_yaml)
make_fake_file(pipeline_log)
distiset.pipeline_path = pipeline_yaml
- distiset.pipeline_log_path = pipeline_log
+ distiset.log_filename_path = pipeline_log
+ return distiset
+
+
+def add_artifacts_to_distiset(distiset: Distiset, folder: Path) -> Distiset:
+ from distilabel.constants import DISTISET_ARTIFACTS_FOLDER
+
+ artifacts_folder = folder / DISTISET_ARTIFACTS_FOLDER
+
+ for step in ("leaf_step_1", "leaf_step_2"):
+ step_artifacts_folder = artifacts_folder / step
+ step_artifacts_folder.mkdir(parents=True)
+ artifact_folder = step_artifacts_folder / "artifact"
+ artifact_folder.mkdir()
+ metadata_file = artifact_folder / "metadata.json"
+ write_json(metadata_file, {})
+
+ distiset.artifacts_path = artifacts_folder
+
return distiset
@@ -64,54 +83,77 @@ def test_train_test_split(self, distiset: Distiset) -> None:
@pytest.mark.parametrize("storage_options", [None, {"test": "option"}])
@pytest.mark.parametrize("with_config", [False, True])
+ @pytest.mark.parametrize("with_artifacts", [False, True])
def test_save_to_disk(
self,
distiset: Distiset,
with_config: bool,
+ with_artifacts: bool,
storage_options: Optional[Dict[str, Any]],
) -> None:
full_distiset = copy.deepcopy(distiset)
# Distiset with Distiset
with tempfile.TemporaryDirectory() as tmpdirname:
folder = Path(tmpdirname) / "distiset_folder"
+ another_folder = Path(tmpdirname) / "another_distiset_folder"
+
if with_config:
full_distiset = add_config_to_distiset(full_distiset, folder)
+ if with_artifacts:
+ full_distiset = add_artifacts_to_distiset(full_distiset, folder)
+
full_distiset.save_to_disk(
- folder,
+ another_folder,
save_card=with_config,
save_pipeline_config=with_config,
save_pipeline_log=with_config,
storage_options=storage_options,
)
- assert folder.is_dir()
- assert len(list(folder.iterdir())) == 3
+ assert another_folder.is_dir()
+
+ if with_artifacts:
+ assert len(list(another_folder.iterdir())) == 4
+ else:
+ assert len(list(another_folder.iterdir())) == 3
full_distiset = copy.deepcopy(distiset)
# Distiset with DatasetDict
distiset_with_dict = full_distiset.train_test_split(0.8)
with tempfile.TemporaryDirectory() as tmpdirname:
folder = Path(tmpdirname) / "distiset_folder"
+ another_folder = Path(tmpdirname) / "another_distiset_folder"
+
if with_config:
distiset_with_dict = add_config_to_distiset(distiset_with_dict, folder)
+ if with_artifacts:
+ distiset_with_dict = add_artifacts_to_distiset(
+ distiset_with_dict, folder
+ )
+
distiset_with_dict.save_to_disk(
- folder,
+ another_folder,
save_card=with_config,
save_pipeline_config=with_config,
save_pipeline_log=with_config,
)
- assert folder.is_dir()
- assert len(list(folder.iterdir())) == 3
+ assert another_folder.is_dir()
+ if with_artifacts:
+ assert len(list(another_folder.iterdir())) == 4
+ else:
+ assert len(list(another_folder.iterdir())) == 3
@pytest.mark.parametrize("pathlib_implementation", [Path, UPath])
@pytest.mark.parametrize("storage_options", [None, {"project": "experiments"}])
@pytest.mark.parametrize("with_config", [False, True])
+ @pytest.mark.parametrize("with_artifacts", [False, True])
def test_load_from_disk(
self,
distiset: Distiset,
with_config: bool,
+ with_artifacts: bool,
storage_options: Optional[Dict[str, Any]],
pathlib_implementation: type,
) -> None:
@@ -121,17 +163,25 @@ def test_load_from_disk(
# This way we can test also we work with UPath, using FilePath protocol, as it should
# do the same as S3Path, GCSPath, etc.
folder = pathlib_implementation(tmpdirname) / "distiset_folder"
+ another_folder = (
+ pathlib_implementation(tmpdirname) / "another_distiset_folder"
+ )
+
if with_config:
full_distiset = add_config_to_distiset(full_distiset, folder)
+
+ if with_artifacts:
+ full_distiset = add_artifacts_to_distiset(full_distiset, folder)
+
full_distiset.save_to_disk(
- folder,
+ another_folder,
save_card=with_config,
save_pipeline_config=with_config,
save_pipeline_log=with_config,
storage_options=storage_options,
)
ds = Distiset.load_from_disk(
- folder,
+ another_folder,
storage_options=storage_options,
)
assert isinstance(ds, Distiset)
@@ -141,24 +191,41 @@ def test_load_from_disk(
assert ds.pipeline_path.exists()
assert ds.log_filename_path.exists()
+ if with_artifacts:
+ assert ds.artifacts_path.exists()
+
full_distiset = copy.deepcopy(distiset)
# Distiset with DatasetDict
distiset_with_dict = full_distiset.train_test_split(0.8)
with tempfile.TemporaryDirectory() as tmpdirname:
folder = pathlib_implementation(tmpdirname) / "distiset_folder"
+ another_folder = (
+ pathlib_implementation(tmpdirname) / "another_distiset_folder"
+ )
+
if with_config:
distiset_with_dict = add_config_to_distiset(distiset_with_dict, folder)
- distiset_with_dict.save_to_disk(folder)
- ds = Distiset.load_from_disk(folder, storage_options=storage_options)
+ if with_artifacts:
+ distiset_with_dict = add_artifacts_to_distiset(
+ distiset_with_dict, folder
+ )
- assert folder.is_dir()
+ distiset_with_dict.save_to_disk(another_folder)
+ ds = Distiset.load_from_disk(
+ another_folder, storage_options=storage_options
+ )
+
+ assert another_folder.is_dir()
assert isinstance(ds["leaf_step_1"], DatasetDict)
if with_config:
assert ds.pipeline_path.exists()
assert ds.log_filename_path.exists()
+ if with_artifacts:
+ assert ds.artifacts_path.exists()
+
def test_dataset_card(self, distiset: Distiset) -> None:
# Test the the metadata we generate by default without extracting the already generated content from the HF hub.
# We parse the content and check it's the same as the one we generate.
diff --git a/tests/unit/test_errors.py b/tests/unit/test_errors.py
new file mode 100644
index 0000000000..420c6ca0af
--- /dev/null
+++ b/tests/unit/test_errors.py
@@ -0,0 +1,27 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.errors import DistilabelUserError
+
+
+def test_distilabel_user_error() -> None:
+ msg = DistilabelUserError("This is an error message.")
+ assert str(msg) == "This is an error message."
+ msg = DistilabelUserError(
+ "This is an error message.", page="sections/getting_started/faq/"
+ )
+ assert (
+ str(msg)
+ == "This is an error message.\n\nFor further information visit 'https://distilabel.argilla.io/latest/sections/getting_started/faq/'"
+ )