From 3d07c2206e7c307062a4ab10ab8a4e91cb8cc25e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 9 Jul 2024 12:55:42 +0200 Subject: [PATCH] Add `RayPipeline` class (#769) * Create N replicas per `Step` * Update `_BatchManager` to handle batch sorting uncertainty * Add multiple replicas test * Fix unit tests * Fix `next_expected_seq_no` needed to be updated if `routing_batch_function` * Update `set_next_expected_batch_seq_no` only if no `data` * Fix `next_expected_seq_no` with `routing_batch_function` * Remove prints * Add `StepResource` import * Add missing return type hint * Add `StepResources` docs * Add `get_steps_load_stages` method * Update to load steps in stages * Add `_teardown` method * Add load stages * Add printing info about stages * Refactor load stages to avoid race conditions * Add load stages integration test * Fix unit tests * Add unit tests for new methods * Move send last batch message * Refactor to make it work with routing batch function * Add integration test for load stages & routing batch function * Update docs to tell about resources as runtime parameters * Add missing doc pages * Add `ray>=2.31.0` optional dependency * Initial work for `RayPipeline` * Update to load stages from cache * Fix bugs requesting initial batches * Add integration tests for recovering states from cache * Remove atexit * Move `_ProcessWrapper` to different file * `RayPipeline` mvp * Install `ray` if `python!=3.12` * Assign ray actor name * Fix setting `options` for Ray actor * Set name for all the queues * Add requirements * Add docstrings * Remove unit test * Add extra `resources` * Add `ray` method * Add `ray[default]` as dependency * Add `script_executed_in_ray_cluster` function * Fix step load fail didn't stop the pipeline * Run with `RayPipeline` if detected Ray cluster * Set built dag * Fix unit tests * Add `Pipeline` to `RayPipeline` unit tests * Add `ray_init_kwargs` argument * Add `memory` attribute * Add simple `RayPipeline` integration test * Override `RayPipeline.dump` method * Add docs for `RayPipeline` * Fix close PR docs --- .github/workflows/docs-pr-close.yml | 2 +- .../advanced/scaling_with_ray.md | 210 ++++++++++ mkdocs.yml | 9 +- pyproject.toml | 1 + scripts/install_dependencies.sh | 4 + src/distilabel/pipeline/__init__.py | 3 +- src/distilabel/pipeline/base.py | 62 ++- src/distilabel/pipeline/local.py | 362 +++--------------- src/distilabel/pipeline/ray.py | 309 +++++++++++++++ src/distilabel/pipeline/step_wrapper.py | 315 +++++++++++++++ src/distilabel/steps/base.py | 11 + src/distilabel/utils/logging.py | 5 +- src/distilabel/utils/ray.py | 28 ++ tests/integration/__init__.py | 14 + tests/integration/test_pipe_simple.py | 11 +- tests/integration/test_ray_pipeline.py | 183 +++++++++ tests/unit/pipeline/test_base.py | 34 +- tests/unit/pipeline/test_local.py | 65 +++- tests/unit/pipeline/test_ray.py | 27 ++ tests/unit/steps/argilla/test_base.py | 12 + tests/unit/steps/argilla/test_preference.py | 12 + .../steps/argilla/test_text_generation.py | 12 + .../steps/tasks/evol_instruct/test_base.py | 12 + .../tasks/evol_instruct/test_generator.py | 12 + .../steps/tasks/evol_quality/test_base.py | 12 + tests/unit/steps/tasks/test_base.py | 12 + tests/unit/steps/tasks/test_pair_rm.py | 12 + tests/unit/steps/test_base.py | 14 + tests/unit/steps/test_decorator.py | 2 + tests/unit/utils/test_ray.py | 32 ++ 30 files changed, 1429 insertions(+), 370 deletions(-) create mode 100644 docs/sections/how_to_guides/advanced/scaling_with_ray.md create 
mode 100644 src/distilabel/pipeline/ray.py
 create mode 100644 src/distilabel/pipeline/step_wrapper.py
 create mode 100644 src/distilabel/utils/ray.py
 create mode 100644 tests/integration/__init__.py
 create mode 100644 tests/integration/test_ray_pipeline.py
 create mode 100644 tests/unit/pipeline/test_ray.py
 create mode 100644 tests/unit/utils/test_ray.py

diff --git a/.github/workflows/docs-pr-close.yml b/.github/workflows/docs-pr-close.yml
index 2b60b15db..f38f5b5b0 100644
--- a/.github/workflows/docs-pr-close.yml
+++ b/.github/workflows/docs-pr-close.yml
@@ -19,7 +19,7 @@ jobs:
           python-version: ${{ matrix.python-version }}

       - name: Install dependencies
-        run: pip install -e .[docs]
+        run: pip install mike

       - name: Set git credentials
         run: |
diff --git a/docs/sections/how_to_guides/advanced/scaling_with_ray.md b/docs/sections/how_to_guides/advanced/scaling_with_ray.md
new file mode 100644
index 000000000..de7d8ceab
--- /dev/null
+++ b/docs/sections/how_to_guides/advanced/scaling_with_ray.md
@@ -0,0 +1,210 @@
+# Scaling and distributing a pipeline with Ray
+
+Although the local [Pipeline][distilabel.pipeline.local.Pipeline] based on [`multiprocessing`](https://docs.python.org/3/library/multiprocessing.html) + [serving LLMs with an external service](serving_an_llm_for_reuse.md) is enough for executing most of the pipelines used to create SFT and preference datasets, there are scenarios where you might need to scale your pipeline across multiple machines. In such cases, distilabel leverages [Ray](https://www.ray.io/) to distribute the workload efficiently. This allows you to generate larger datasets, reduce execution time, and maximize resource utilization across a cluster of machines, without needing to change a single line of code.
+
+## Relation between distilabel steps and Ray Actors
+
+A `distilabel` pipeline consists of several [`Step`][distilabel.steps.base.Step]s. A `Step` is a class that defines a basic life-cycle:
+
+1. It will load or create the resources (LLMs, clients, etc.) required to run its logic.
+2. It will run a loop, waiting for incoming batches received via a queue. Once it receives one batch, it will process it and put the processed batch into an output queue.
+3. When it finishes processing the final batch, or receives a special signal, the loop will end and the unload logic will be executed.
+
+So a `Step` needs to maintain a minimal state, and the best way to do that with Ray is using [actors](https://docs.ray.io/en/latest/ray-core/actors.html).
+
+``` mermaid
+graph TD
+    A[Step] -->|has| B[Multiple Replicas]
+    B -->|wrapped in| C[Ray Actor]
+    C -->|maintains| D[Step Replica State]
+    C -->|executes| E[Step Lifecycle]
+    E -->|1. Load/Create Resources| F[LLMs, Clients, etc.]
+    E -->|2. Process batches from| G[Input Queue]
+    E -->|3. Processed batches are put in| H[Output Queue]
+    E -->|4. Unload| I[Cleanup]
+
+```
+
+## Executing a pipeline with Ray
+
+The recommended way to execute a `distilabel` pipeline using Ray is via the [Ray Jobs API](https://docs.ray.io/en/latest/cluster/running-applications/job-submission/index.html#ray-jobs-api).
+
+Before jumping into the explanation, let's first install the prerequisites:
+```bash
+pip install distilabel[ray]
+```
+
+!!! tip
+
+    It's recommended to create a virtual environment.
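+
+Once the dependencies are installed, a quick way to check that Ray works on your machine is to start a throwaway local instance and print the resources it detects (an optional sanity check, not required for the rest of the guide):
+
+```python
+import ray
+
+# Starts a local Ray instance, since we're not connecting to any cluster yet
+ray.init()
+
+# Prints the resources Ray detected, e.g. `{'CPU': 8.0, 'memory': ...}`
+print(ray.cluster_resources())
+
+ray.shutdown()
+```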
+
+For the purpose of explaining how to execute a pipeline with Ray, we'll use the following pipeline throughout the examples:
+
+```python
+from distilabel.llms import vLLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import LoadDataFromHub
+from distilabel.steps.tasks import TextGeneration
+
+with Pipeline(name="text-generation-ray-pipeline") as pipeline:
+    load_data_from_hub = LoadDataFromHub(output_mappings={"prompt": "instruction"})
+
+    text_generation = TextGeneration(
+        llm=vLLM(
+            model="meta-llama/Meta-Llama-3-8B-Instruct",
+            tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
+        )
+    )
+
+    load_data_from_hub >> text_generation
+
+if __name__ == "__main__":
+    distiset = pipeline.run(
+        parameters={
+            load_data_from_hub.name: {
+                "repo_id": "HuggingFaceH4/instruction-dataset",
+                "split": "test",
+            },
+            text_generation.name: {
+                "llm": {
+                    "generation_kwargs": {
+                        "temperature": 0.7,
+                        "max_new_tokens": 4096,
+                    }
+                },
+                "resources": {"replicas": 2, "gpus": 1}, # (1)
+            },
+        }
+    )
+
+    distiset.push_to_hub(
+        "<your_user_or_org>/text-generation-distilabel-ray" # (2)
+    )
+```
+
+1. We're setting [resources](assigning_resources_to_step.md) for the `text_generation` step and defining that we want two replicas and one GPU per replica. `distilabel` will create two replicas of the step, i.e. two actors in the Ray cluster, and each actor will request to be allocated on a node of the cluster that has at least one GPU. You can read more about how Ray manages the resources [here](https://docs.ray.io/en/latest/ray-core/scheduling/resources.html#resources).
+2. You should modify this and add your user or organization on the Hugging Face Hub.
+
+It's a basic pipeline with just two steps: one to load a dataset from the Hub with an `instruction` column and one to generate a `response` for that instruction using Llama 3 8B Instruct with [vLLM](/distilabel/components-gallery/llms/vllm/). Simple, but enough to demonstrate how to distribute and scale the workload using a Ray cluster!
+
+### Using Ray Jobs API
+
+If you don't know the Ray Jobs API, then it's recommended to read the [Ray Jobs Overview](https://docs.ray.io/en/latest/cluster/running-applications/job-submission/index.html#ray-jobs-overview). Quick summary: Ray Jobs is the recommended way to execute a job in a Ray cluster, as it will handle packaging, deploying and managing the Ray application.
+
+To execute the pipeline above, we first need to create a directory (kind of a package) with the pipeline script (or scripts)
+that we will submit to the Ray cluster:
+
+```bash
+mkdir ray-pipeline
+```
+
+The content of the directory `ray-pipeline` should be:
+
+```
+ray-pipeline/
+├── pipeline.py
+└── runtime_env.yaml
+```
+
+The first file contains the code of the pipeline, while the second one (`runtime_env.yaml`) is a Ray-specific file containing the [environment dependencies](https://docs.ray.io/en/latest/ray-core/handling-dependencies.html#environment-dependencies) required to run the job:
+
+```yaml
+pip:
+  - distilabel[ray,vllm] >= 1.3.0
+env_vars:
+  HF_TOKEN: <your_hf_token>
+```
+
+With this file, we're basically informing the Ray cluster that it will have to install `distilabel` with the `vllm` and `ray` extra dependencies to be able to run the job. In addition, we're defining the `HF_TOKEN` environment variable that will be used (by the `push_to_hub` method) to upload the resulting dataset to the Hugging Face Hub.
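+
+Since the token is injected as an environment variable, `pipeline.py` can simply rely on it being present. If you want the job to fail fast when the token is missing, instead of failing later in `push_to_hub`, you can add a small guard at the top of the script (an optional suggestion, not part of the pipeline above):
+
+```python
+import os
+
+# Fail early if the runtime environment didn't inject the token
+if "HF_TOKEN" not in os.environ:
+    raise RuntimeError(
+        "`HF_TOKEN` is not set. Check the `env_vars` section of `runtime_env.yaml`."
+    )
+```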
+
+After that, we can proceed to execute the `ray` command that will submit the job to the Ray cluster:
+
+```bash
+ray job submit \
+    --address http://localhost:8265 \
+    --working-dir ray-pipeline \
+    --runtime-env ray-pipeline/runtime_env.yaml -- python pipeline.py
+```
+
+What this will do is basically upload the `--working-dir` to the Ray cluster, install the dependencies, and then execute the `python pipeline.py` command from the head node.
+
+## File system requirements
+
+As described in [Using a file system to pass data to steps](fs_to_pass_data.md), `distilabel` relies on the file system to pass the data to the `GlobalStep`s, so if the pipeline to be executed in the Ray cluster has any `GlobalStep`, or you want to set `use_fs_to_pass_data=True` in the [run][distilabel.pipeline.local.Pipeline.run] method, then you will need to set up a file system to which all the nodes of the Ray cluster have access:
+
+```python
+if __name__ == "__main__":
+    distiset = pipeline.run(
+        parameters={...},
+        storage_parameters={"path": "file:///mnt/data"},  # (1)
+        use_fs_to_pass_data=True,
+    )
+```
+
+1. All the nodes of the Ray cluster should have access to `/mnt/data`.
+
+## Executing a `RayPipeline` in a cluster with Slurm
+
+If you have access to an HPC, then you're probably also a user of [Slurm](https://slurm.schedmd.com/), a workload manager typically used on HPCs. We can create a Slurm job that takes some nodes and deploys a Ray cluster to run a distributed `distilabel` pipeline:
+
+```bash
+#!/bin/bash
+#SBATCH --job-name=distilabel-ray-text-generation
+#SBATCH --partition=your-partition
+#SBATCH --qos=normal
+#SBATCH --nodes=2 # (1)
+#SBATCH --exclusive
+#SBATCH --ntasks-per-node=1 # (2)
+#SBATCH --gpus-per-node=1 # (3)
+#SBATCH --time=0:30:00
+
+set -ex
+
+echo "SLURM_JOB_ID: $SLURM_JOB_ID"
+echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
+
+# Activate virtual environment
+source /path/to/virtualenv/.venv/bin/activate
+
+# Getting the node names
+nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
+nodes_array=($nodes)
+
+# Get the IP address of the head node
+head_node=${nodes_array[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+
+# Start Ray head node
+port=6379
+ip_head=$head_node_ip:$port
+export ip_head
+echo "IP Head: $ip_head"
+
+echo "Starting HEAD at $head_node"
+srun --nodes=1 --ntasks=1 -w "$head_node" \
+    ray start --head --node-ip-address="$head_node_ip" --port=$port \
+    --dashboard-host=0.0.0.0 \
+    --block &
+
+# Give the head node some time to start...
+sleep 10
+
+# Start Ray worker nodes
+worker_num=$((SLURM_JOB_NUM_NODES - 1))
+
+# Start from 1 (0 is the head node)
+for ((i = 1; i <= worker_num; i++)); do
+    node_i=${nodes_array[$i]}
+    echo "Starting WORKER $i at $node_i"
+    srun --nodes=1 --ntasks=1 -w "$node_i" \
+        ray start --address "$ip_head" \
+        --block &
+    sleep 5
+done
+
+# Finally submit the job to the cluster
+ray job submit --address http://localhost:8265 --working-dir ray-pipeline -- python -u pipeline.py
+```
+
+1. In this case, we just want two nodes: one to run the Ray head node and one to run a worker.
+2. We just want to run one task per node, i.e. the Ray command that starts the head/worker node.
+3. We have selected 1 GPU per node, but we could have selected more depending on the pipeline.
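+
+!!! note
+
+    Everything we did above with the `ray job submit` CLI can also be done programmatically using Ray's job submission SDK, which can be handy for submitting jobs from CI scripts. A minimal sketch, assuming the same `ray-pipeline` directory and a Ray dashboard reachable at `http://localhost:8265`:
+
+    ```python
+    from ray.job_submission import JobSubmissionClient
+
+    client = JobSubmissionClient("http://localhost:8265")
+
+    # Uploads the working directory and runs `python pipeline.py` on the cluster
+    job_id = client.submit_job(
+        entrypoint="python pipeline.py",
+        runtime_env={
+            "working_dir": "ray-pipeline",
+            "pip": ["distilabel[ray,vllm] >= 1.3.0"],
+        },
+    )
+
+    print(client.get_job_status(job_id))
+    ```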
diff --git a/mkdocs.yml b/mkdocs.yml index a23581e8d..e2f3e359f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -155,14 +155,15 @@ nav: - Execute Steps and Tasks in a Pipeline: "sections/how_to_guides/basic/pipeline/index.md" - Advanced: - Using the Distiset dataset object: "sections/how_to_guides/advanced/distiset.md" + - Cache and recover pipeline executions: "sections/how_to_guides/advanced/caching.md" - Export data to Argilla: "sections/how_to_guides/advanced/argilla.md" - - Using a file system to pass data of batches between steps: "sections/how_to_guides/advanced/fs_to_pass_data.md" + - Structured data generation: "sections/how_to_guides/advanced/structured_generation.md" + - Specify requirements for pipelines and steps: "sections/how_to_guides/advanced/pipeline_requirements.md" - Using CLI to explore and re-run existing Pipelines: "sections/how_to_guides/advanced/cli/index.md" - - Cache and recover pipeline executions: "sections/how_to_guides/advanced/caching.md" + - Using a file system to pass data of batches between steps: "sections/how_to_guides/advanced/fs_to_pass_data.md" - Assigning resources to a step: "sections/how_to_guides/advanced/assigning_resources_to_step.md" - - Structured data generation: "sections/how_to_guides/advanced/structured_generation.md" - Serving an LLM for sharing it between several tasks: "sections/how_to_guides/advanced/serving_an_llm_for_reuse.md" - - Specify requirements for pipelines and steps: "sections/how_to_guides/advanced/pipeline_requirements.md" + - Scaling and distributing a pipeline with Ray: "sections/how_to_guides/advanced/scaling_with_ray.md" - Pipeline Samples: - Examples: "sections/pipeline_samples/examples/index.md" - Papers: diff --git a/pyproject.toml b/pyproject.toml index 58f3b3226..263f1aaac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ mistralai = ["mistralai >= 0.1.0"] ollama = ["ollama >= 0.1.7"] openai = ["openai >= 1.0.0"] outlines = ["outlines >= 0.0.40"] +ray = ["ray[default] >= 2.31.0"] vertexai = ["google-cloud-aiplatform >= 1.38.0"] vllm = ["vllm >= 0.4.0", "outlines == 0.0.34", "filelock >= 3.13.4"] diff --git a/scripts/install_dependencies.sh b/scripts/install_dependencies.sh index dc2ef0f3d..06f52c402 100755 --- a/scripts/install_dependencies.sh +++ b/scripts/install_dependencies.sh @@ -8,4 +8,8 @@ python -m pip install uv uv pip install --system -e ".[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor]" +if [ "${python_version}" != "(3, 12)" ]; then + uv pip install --system -e .[ray] +fi + uv pip install --system git+https://github.com/argilla-io/LLM-Blender.git diff --git a/src/distilabel/pipeline/__init__.py b/src/distilabel/pipeline/__init__.py index 3c33160de..4a5115170 100644 --- a/src/distilabel/pipeline/__init__.py +++ b/src/distilabel/pipeline/__init__.py @@ -13,9 +13,10 @@ # limitations under the License. 
from distilabel.pipeline.local import Pipeline +from distilabel.pipeline.ray import RayPipeline from distilabel.pipeline.routing_batch_function import ( routing_batch_function, sample_n_steps, ) -__all__ = ["Pipeline", "routing_batch_function", "sample_n_steps"] +__all__ = ["Pipeline", "RayPipeline", "routing_batch_function", "sample_n_steps"] diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 3e6229f40..4e20e53d7 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -137,8 +137,6 @@ class BasePipeline(ABC, RequirementsMixin, _Serializable): _write_buffer: The buffer that will store the data of the leaf steps of the pipeline while running, so the `Distiset` can be created at the end. It will be created when the pipeline is run. Defaults to `None`. - _logging_parameters: A dictionary containing the parameters that will passed to - `setup_logging` function to initialize the logging. Defaults to `{}`. _fs: The `fsspec` filesystem to be used to store the data of the `_Batch`es passed between the steps. It will be set when the pipeline is run. Defaults to `None`. _storage_base_path: The base path where the data of the `_Batch`es passed between @@ -176,7 +174,7 @@ def __init__( in the final `Distiset`. It contains metadata used by distilabel, for example the raw outputs of the `LLM` without processing would be here, inside `raw_output_...` field. Defaults to `False`. - requirements: List of requirements that must be installed to run the Pipeline. + requirements: List of requirements that must be installed to run the pipeline. Defaults to `None`, but can be helpful to inform in a pipeline to be shared that this requirements must be installed. """ @@ -196,9 +194,6 @@ def __init__( self._batch_manager: Optional["_BatchManager"] = None self._write_buffer: Optional["_WriteBuffer"] = None - self._logging_parameters: Dict[str, Any] = { - "filename": self._cache_location["log_file"] - } self._steps_load_status: Dict[str, int] = {} self._steps_load_status_lock = threading.Lock() @@ -219,6 +214,8 @@ def __init__( self._exception: Union[Exception, None] = None + self._log_queue: Union["Queue[Any]", None] = None + def __enter__(self) -> Self: """Set the global pipeline instance when entering a pipeline context.""" _GlobalPipelineManager.set_pipeline(self) @@ -286,16 +283,6 @@ def _create_signature(self) -> str: return hasher.hexdigest() - def _set_logging_parameters(self, parameters: Dict[str, Any]) -> None: - """Set the parameters that will be passed to the `setup_logging` function to - initialize the logging. - - Args: - parameters: A dictionary with the parameters that will be passed to the - `setup_logging` function. 
- """ - self._logging_parameters = parameters - def run( self, parameters: Optional[Dict[str, Dict[str, Any]]] = None, @@ -338,10 +325,7 @@ def run( self._set_runtime_parameters(parameters or {}) setup_logging( - **{ - **self._logging_parameters, - "filename": str(self._cache_location["log_file"]), - } + log_queue=self._log_queue, filename=str(self._cache_location["log_file"]) ) # Validate the pipeline DAG to check that all the steps are chainable, there are @@ -692,7 +676,8 @@ def _run_output_queue_loop_in_thread(self) -> threading.Thread: def _output_queue_loop(self) -> None: """Loop to receive the output batches from the steps and manage the flow of the batches through the pipeline.""" - self._initialize_pipeline_execution() + if not self._initialize_pipeline_execution(): + return while self._should_continue_processing(): # type: ignore self._logger.debug("Waiting for output batch from step...") @@ -717,24 +702,32 @@ def _output_queue_loop(self) -> None: # If there is another load stage and all the `last_batch`es from the stage # have been received, then load the next stage. - self._update_stage() + if self._should_load_next_stage(): + if not self._update_stage(): + break self._manage_batch_flow(batch) self._finalize_pipeline_execution() - def _initialize_pipeline_execution(self) -> None: + def _initialize_pipeline_execution(self) -> bool: """Load the steps of the required stage to initialize the pipeline execution, - and requests the initial batches to trigger the batch flowing in the pipeline.""" + and requests the initial batches to trigger the batch flowing in the pipeline. + + Returns: + `True` if initialization went OK, `False` otherwise. + """ # Wait for all the steps to be loaded correctly if not self._run_stage_steps_and_wait(stage=self._current_stage): self._set_steps_not_loaded_exception() - return + return False # Send the "first" batches to the steps so the batches starts flowing through # the input queues and output queue self._request_initial_batches() + return True + def _should_continue_processing(self) -> bool: """Condition for the consume batches from the `output_queue` loop. @@ -773,14 +766,19 @@ def _process_batch(self, batch: "_Batch") -> None: if self._is_step_running(step_name): self._send_last_batch_flag_to_step(step_name) - def _update_stage(self) -> None: + def _update_stage(self) -> bool: """Checks if the steps of next stage should be loaded and updates `_current_stage` - attribute.""" - if self._should_load_next_stage(): - self._current_stage += 1 - if not self._run_stage_steps_and_wait(stage=self._current_stage): - self._set_steps_not_loaded_exception() - return + attribute. + + Returns: + `True` if updating the stage went OK, `False` otherwise. + """ + self._current_stage += 1 + if not self._run_stage_steps_and_wait(stage=self._current_stage): + self._set_steps_not_loaded_exception() + return False + + return True def _should_load_next_stage(self) -> bool: """Returns if the next stage should be loaded. 
diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index 874757a6d..9285a0f21 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -15,29 +15,24 @@ import multiprocessing as mp import signal import sys -import traceback -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast import tblib from distilabel.distiset import create_distiset -from distilabel.llms.mixins import CudaDevicePlacementMixin from distilabel.pipeline.base import ( BasePipeline, ) -from distilabel.pipeline.batch import _Batch -from distilabel.pipeline.constants import ( - LAST_BATCH_SENT_FLAG, -) -from distilabel.steps.tasks.base import Task +from distilabel.pipeline.ray import RayPipeline +from distilabel.pipeline.step_wrapper import _StepWrapper, _StepWrapperException from distilabel.utils.logging import setup_logging, stop_logging +from distilabel.utils.ray import script_executed_in_ray_cluster if TYPE_CHECKING: from queue import Queue from distilabel.distiset import Distiset - from distilabel.pipeline.typing import StepLoadStatus - from distilabel.steps.base import GeneratorStep, Step, _Step + from distilabel.steps.base import _Step _SUBPROCESS_EXCEPTION: Union[Exception, None] = None @@ -56,6 +51,39 @@ def _init_worker(log_queue: "Queue[Any]") -> None: class Pipeline(BasePipeline): """Local pipeline implementation using `multiprocessing`.""" + def ray( + self, + ray_head_node_url: Optional[str] = None, + ray_init_kwargs: Optional[Dict[str, Any]] = None, + ) -> RayPipeline: + """Creates a `RayPipeline` using the init parameters of this pipeline. This is a + convenient method that can be used to "transform" one common `Pipeline` to a `RayPipeline` + and it's mainly used by the CLI. + + Args: + ray_head_node_url: The URL that can be used to connect to the head node of + the Ray cluster. Normally, you won't want to use this argument as the + recommended way to submit a job to a Ray cluster is using the [Ray Jobs + CLI](https://docs.ray.io/en/latest/cluster/running-applications/job-submission/index.html#ray-jobs-overview). + Defaults to `None`. + ray_init_kwargs: kwargs that will be passed to the `ray.init` method. Defaults + to `None`. + + Returns: + A `RayPipeline` instance. + """ + pipeline = RayPipeline( + name=self.name, + description=self.description, + cache_dir=self._cache_dir, + enable_metadata=self._enable_metadata, + requirements=self.requirements, + ray_head_node_url=ray_head_node_url, + ray_init_kwargs=ray_init_kwargs, + ) + pipeline.dag = self.dag + return pipeline + def run( self, parameters: Optional[Dict[str, Dict[str, Any]]] = None, @@ -88,11 +116,16 @@ def run( Raises: RuntimeError: If the pipeline fails to load all the steps. """ - log_queue = mp.Queue() + if script_executed_in_ray_cluster(): + print("Script running in Ray cluster... 
Using `RayPipeline`...") + return self.ray().run( + parameters=parameters, + use_cache=use_cache, + storage_parameters=storage_parameters, + use_fs_to_pass_data=use_fs_to_pass_data, + ) - self._set_logging_parameters( - {"log_queue": log_queue, "filename": self._cache_location["log_file"]} - ) + self._log_queue = cast("Queue[Any]", mp.Queue()) if distiset := super().run( parameters, use_cache, storage_parameters, use_fs_to_pass_data @@ -106,7 +139,7 @@ def run( ctx.Pool( num_processes, initializer=_init_worker, - initargs=(log_queue,), + initargs=(self._log_queue,), ) as pool, ): self._manager = manager @@ -159,7 +192,7 @@ def _run_step(self, step: "_Step", input_queue: "Queue[Any]", replica: int) -> N """ assert self._pool, "Pool is not initialized" - process_wrapper = _ProcessWrapper( + step_wrapper = _StepWrapper( step=step, # type: ignore replica=replica, input_queue=input_queue, @@ -168,7 +201,7 @@ def _run_step(self, step: "_Step", input_queue: "Queue[Any]", replica: int) -> N dry_run=self._dry_run, ) - self._pool.apply_async(process_wrapper.run, error_callback=self._error_callback) + self._pool.apply_async(step_wrapper.run, error_callback=self._error_callback) def _error_callback(self, e: BaseException) -> None: """Error callback that will be called when an error occurs in a `Step` process. @@ -178,9 +211,9 @@ def _error_callback(self, e: BaseException) -> None: """ global _SUBPROCESS_EXCEPTION - # First we check that the exception is a `_ProcessWrapperException`, otherwise, we + # First we check that the exception is a `_StepWrapperException`, otherwise, we # print it out and stop the pipeline, since some errors may be unhandled - if not isinstance(e, _ProcessWrapperException): + if not isinstance(e, _StepWrapperException): self._logger.error(f"❌ Failed with an unhandled exception: {e}") self._stop() return @@ -287,294 +320,3 @@ def _stop(self) -> None: self._stop_load_queue_loop() self._stop_output_queue_loop() - - -class _ProcessWrapperException(Exception): - """Exception to be raised when an error occurs in the `Step` process. - - Attributes: - message: The error message. - step: The `Step` that raised the error. - code: The error code. - subprocess_exception: The exception raised by the subprocess. Defaults to `None`. - """ - - def __init__( - self, - message: str, - step: "_Step", - code: int, - subprocess_exception: Optional[Exception] = None, - ) -> None: - self.message = message - self.step = step - self.code = code - self.subprocess_exception = subprocess_exception - self.formatted_traceback = "".join( - traceback.format_exception(subprocess_exception) - ) - - @classmethod - def create_load_error( - cls, - message: str, - step: "_Step", - subprocess_exception: Optional[Exception] = None, - ) -> "_ProcessWrapperException": - """Creates a `_ProcessWrapperException` for a load error. - - Args: - message: The error message. - step: The `Step` that raised the error. - subprocess_exception: The exception raised by the subprocess. Defaults to `None`. - - Returns: - The `_ProcessWrapperException` instance. - """ - return cls(message, step, 1, subprocess_exception) - - @property - def is_load_error(self) -> bool: - """Whether the error is a load error. - - Returns: - `True` if the error is a load error, `False` otherwise. - """ - return self.code == 1 - - -class _ProcessWrapper: - """Wrapper to run the `Step` in a separate process. - - Attributes: - step: The step to run. - replica: The replica ID assigned. - input_queue: The queue to receive the input data. 
- output_queue: The queue to send the output data. - load_queue: The queue used to notify the main process that the step has been loaded, - has been unloaded or has failed to load. - """ - - def __init__( - self, - step: Union["Step", "GeneratorStep"], - replica: int, - input_queue: "Queue[_Batch]", - output_queue: "Queue[_Batch]", - load_queue: "Queue[Union[StepLoadStatus, None]]", - dry_run: bool = False, - ) -> None: - """Initializes the `_ProcessWrapper`. - - Args: - step: The step to run. - input_queue: The queue to receive the input data. - output_queue: The queue to send the output data. - load_queue: The queue used to notify the main process that the step has been - loaded, has been unloaded or has failed to load. - dry_run: Flag to ensure we are forcing to run the last batch. - """ - self.step = step - self.replica = replica - self.input_queue = input_queue - self.output_queue = output_queue - self.load_queue = load_queue - self._dry_run = dry_run - - if ( - isinstance(self.step, Task) - and hasattr(self.step, "llm") - and isinstance(self.step.llm, CudaDevicePlacementMixin) - ): - self.step.llm._llm_identifier = self.step.name - - def run(self) -> str: - """The target function executed by the process. This function will also handle - the step lifecycle, executing first the `load` function of the `Step` and then - waiting to receive a batch from the `input_queue` that will be handled by the - `process` method of the `Step`. - - Returns: - The name of the step that was executed. - """ - - try: - self.step.load() - self.step._logger.debug(f"Step '{self.step.name}' loaded!") - except Exception as e: - self.step.unload() - self._notify_load_failed() - raise _ProcessWrapperException.create_load_error( - message=f"Step load failed: {e}", - step=self.step, - subprocess_exception=e, - ) from e - - self._notify_load() - - if self.step.is_generator: - self._generator_step_process_loop() - else: - self._non_generator_process_loop() - - # Just in case `None` sentinel was sent - try: - self.input_queue.get(block=False) - except Exception: - pass - - self.step.unload() - - self._notify_unload() - - self.step._logger.info( - f"🏁 Finished running step '{self.step.name}' (replica ID: {self.replica})" - ) - - return self.step.name # type: ignore - - def _notify_load(self) -> None: - """Notifies that the step has finished executing its `load` function successfully.""" - self.load_queue.put({"name": self.step.name, "status": "loaded"}) # type: ignore - - def _notify_unload(self) -> None: - """Notifies that the step has been unloaded.""" - self.load_queue.put({"name": self.step.name, "status": "unloaded"}) # type: ignore - - def _notify_load_failed(self) -> None: - """Notifies that the step failed to load.""" - self.load_queue.put({"name": self.step.name, "status": "load_failed"}) # type: ignore - - def _generator_step_process_loop(self) -> None: - """Runs the process loop for a generator step. It will call the `process` method - of the step and send the output data to the `output_queue` and block until the next - batch request is received (i.e. receiving an empty batch from the `input_queue`). - - If the `last_batch` attribute of the batch is `True`, the loop will stop and the - process will finish. - - Raises: - _ProcessWrapperException: If an error occurs during the execution of the - `process` method. 
- """ - step = cast("GeneratorStep", self.step) - try: - if (batch := self.input_queue.get()) is None: - self.step._logger.info( - f"🛑 Stopping yielding batches from step '{self.step.name}'" - ) - return - - offset = batch.seq_no * step.batch_size # type: ignore - - self.step._logger.info( - f"🧬 Starting yielding batches from generator step '{self.step.name}'." - f" Offset: {offset}" - ) - - for data, last_batch in step.process_applying_mappings(offset=offset): - batch.set_data([data]) - batch.last_batch = self._dry_run or last_batch - self._send_batch(batch) - - if batch.last_batch: - return - - self.step._logger.debug( - f"Step '{self.step.name}' waiting for next batch request..." - ) - if (batch := self.input_queue.get()) is None: - self.step._logger.info( - f"🛑 Stopping yielding batches from step '{self.step.name}'" - ) - return - except Exception as e: - raise _ProcessWrapperException(str(e), self.step, 2, e) from e - - def _non_generator_process_loop(self) -> None: - """Runs the process loop for a non-generator step. It will call the `process` - method of the step and send the output data to the `output_queue` and block until - the next batch is received from the `input_queue`. If the `last_batch` attribute - of the batch is `True`, the loop will stop and the process will finish. - - If an error occurs during the execution of the `process` method and the step is - global, the process will raise a `_ProcessWrapperException`. If the step is not - global, the process will log the error and send an empty batch to the `output_queue`. - - Raises: - _ProcessWrapperException: If an error occurs during the execution of the - `process` method and the step is global. - """ - step = cast("Step", self.step) - while True: - if (batch := self.input_queue.get()) is None: - self.step._logger.info( - f"🛑 Stopping processing batches from step '{self.step.name}'" - ) - break - - if batch == LAST_BATCH_SENT_FLAG: - self.step._logger.debug("Received `LAST_BATCH_SENT_FLAG`. Stopping...") - break - - self.step._logger.info( - f"📦 Processing batch {batch.seq_no} in '{batch.step_name}' (replica ID: {self.replica})" - ) - - if batch.data_path is not None: - self.step._logger.debug(f"Reading batch data from '{batch.data_path}'") - batch.read_batch_data_from_fs() - - result = [] - try: - if self.step.has_multiple_inputs: - result = next(step.process_applying_mappings(*batch.data)) - else: - result = next(step.process_applying_mappings(batch.data[0])) - except Exception as e: - if self.step.is_global: - raise _ProcessWrapperException(str(e), self.step, 2, e) from e - - # Impute step outputs columns with `None` - result = self._impute_step_outputs(batch) - - # if the step is not global then we can skip the batch which means sending - # an empty batch to the output queue - self.step._logger.warning( - f"⚠️ Processing batch {batch.seq_no} with step '{self.step.name}' failed." - " Sending empty batch filled with `None`s..." - ) - self.step._logger.warning( - f"Subprocess traceback:\n\n{traceback.format_exc()}" - ) - finally: - batch.set_data([result]) - self._send_batch(batch) - - if batch.last_batch: - break - - def _impute_step_outputs(self, batch: "_Batch") -> List[Dict[str, Any]]: - """Imputes the step outputs columns with `None` in the batch data. - - Args: - batch: The batch to impute. 
-        """
-        result = []
-        for row in batch.data[0]:
-            data = row.copy()
-            for output in self.step.outputs:
-                data[output] = None
-            result.append(data)
-        return result
-
-    def _send_batch(self, batch: _Batch) -> None:
-        """Sends a batch to the `output_queue`."""
-        if batch.data_path is not None:
-            self.step._logger.debug(f"Writing batch data to '{batch.data_path}'")
-            batch.write_batch_data_to_fs()
-
-        self.step._logger.info(
-            f"📨 Step '{batch.step_name}' sending batch {batch.seq_no} to output queue"
-        )
-        self.output_queue.put(batch)
diff --git a/src/distilabel/pipeline/ray.py b/src/distilabel/pipeline/ray.py
new file mode 100644
index 000000000..2ff95ac88
--- /dev/null
+++ b/src/distilabel/pipeline/ray.py
@@ -0,0 +1,309 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+
+from distilabel.distiset import create_distiset
+from distilabel.pipeline.base import BasePipeline
+from distilabel.pipeline.constants import INPUT_QUEUE_ATTR_NAME
+from distilabel.pipeline.step_wrapper import _StepWrapper
+from distilabel.utils.logging import setup_logging, stop_logging
+from distilabel.utils.serialization import TYPE_INFO_KEY
+
+if TYPE_CHECKING:
+    from os import PathLike
+    from queue import Queue
+
+    from distilabel.distiset import Distiset
+    from distilabel.steps.base import _Step
+
+
+class RayPipeline(BasePipeline):
+    """Ray pipeline implementation that allows running a pipeline in a Ray cluster."""
+
+    def __init__(
+        self,
+        name: str,
+        description: Optional[str] = None,
+        cache_dir: Optional[Union[str, "PathLike"]] = None,
+        enable_metadata: bool = False,
+        requirements: Optional[List[str]] = None,
+        ray_head_node_url: Optional[str] = None,
+        ray_init_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Initialize the `RayPipeline` instance.
+
+        Args:
+            name: The name of the pipeline.
+            description: A description of the pipeline. Defaults to `None`.
+            cache_dir: A directory where the pipeline will be cached. Defaults to `None`.
+            enable_metadata: Whether to include the distilabel metadata column for the pipeline
+                in the final `Distiset`. It contains metadata used by distilabel, for example
+                the raw outputs of the `LLM` without processing would be here, inside `raw_output_...`
+                field. Defaults to `False`.
+            requirements: List of requirements that must be installed to run the pipeline.
+                Defaults to `None`, but can be helpful to inform, in a pipeline to be shared,
+                that these requirements must be installed.
+            ray_head_node_url: The URL that can be used to connect to the head node of
+                the Ray cluster. Normally, you won't want to use this argument as the
+                recommended way to submit a job to a Ray cluster is using the [Ray Jobs
+                CLI](https://docs.ray.io/en/latest/cluster/running-applications/job-submission/index.html#ray-jobs-overview).
+                Defaults to `None`.
+            ray_init_kwargs: kwargs that will be passed to the `ray.init` method. Defaults
+                to `None`.
+ """ + super().__init__(name, description, cache_dir, enable_metadata, requirements) + + self._ray_head_node_url = ray_head_node_url + self._ray_init_kwargs = ray_init_kwargs or {} + + def run( + self, + parameters: Optional[Dict[str, Dict[str, Any]]] = None, + use_cache: bool = True, + storage_parameters: Optional[Dict[str, Any]] = None, + use_fs_to_pass_data: bool = False, + ) -> "Distiset": + """Runs the pipeline in the Ray cluster. + + Args: + parameters: A dictionary with the step name as the key and a dictionary with + the runtime parameters for the step as the value. Defaults to `None`. + use_cache: Whether to use the cache from previous pipeline runs. Defaults to + `True`. + storage_parameters: A dictionary with the storage parameters (`fsspec` and path) + that will be used to store the data of the `_Batch`es passed between the + steps if `use_fs_to_pass_data` is `True` (for the batches received by a + `GlobalStep` it will be always used). It must have at least the "path" key, + and it can contain additional keys depending on the protocol. By default, + it will use the local file system and a directory in the cache directory. + Defaults to `None`. + use_fs_to_pass_data: Whether to use the file system to pass the data of + the `_Batch`es between the steps. Even if this parameter is `False`, the + `Batch`es received by `GlobalStep`s will always use the file system to + pass the data. Defaults to `False`. + + Returns: + The `Distiset` created by the pipeline. + + Raises: + RuntimeError: If the pipeline fails to load all the steps. + """ + self._init_ray() + + self._log_queue = self.QueueClass( + actor_options={"name": f"distilabel-{self.name}-log-queue"} + ) + + if distiset := super().run( + parameters, use_cache, storage_parameters, use_fs_to_pass_data + ): + return distiset + + self._output_queue = self.QueueClass( + actor_options={"name": f"distilabel-{self.name}-output-queue"} + ) + self._load_queue = self.QueueClass( + actor_options={"name": f"distilabel-{self.name}-load-queue"} + ) + self._handle_keyboard_interrupt() + + # Run the loop for receiving the load status of each step + self._load_steps_thread = self._run_load_queue_loop_in_thread() + + # Start a loop to receive the output batches from the steps + self._output_queue_thread = self._run_output_queue_loop_in_thread() + self._output_queue_thread.join() + + self._teardown() + + if self._exception: + stop_logging() + raise self._exception + + distiset = create_distiset( + self._cache_location["data"], + pipeline_path=self._cache_location["pipeline"], + log_filename_path=self._cache_location["log_file"], + enable_metadata=self._enable_metadata, + ) + + stop_logging() + + return distiset + + def _init_ray(self) -> None: + """Inits or connects to a Ray cluster.""" + try: + import ray + except ImportError as ie: + raise ImportError( + "ray is not installed. Please install it using `pip install ray[default]`." + ) from ie + + if self._ray_head_node_url: + ray.init( + self._ray_head_node_url, + runtime_env={"pip": self.requirements}, + **self._ray_init_kwargs, + ) + else: + ray.init(**self._ray_init_kwargs) + + @property + def QueueClass(self) -> Callable: + from ray.util.queue import Queue + + return Queue + + def _create_step_input_queue(self, step_name: str) -> "Queue[Any]": + """Creates an input queue for a step. Override to set actor name. + + Args: + step_name: The name of the step. + + Returns: + The input queue created. 
+ """ + input_queue = self.QueueClass( + actor_options={"name": f"distilabel-{self.name}-input-queue-{step_name}"} + ) + self.dag.set_step_attr(step_name, INPUT_QUEUE_ATTR_NAME, input_queue) + return input_queue + + def _run_step(self, step: "_Step", input_queue: "Queue[Any]", replica: int) -> None: + """Creates a replica of an `Step` using a Ray Actor. + + Args: + step: The step to run. + input_queue: The input queue to send the data to the step. + replica: The replica ID assigned. + """ + import ray + + @ray.remote + class _StepWrapperRay: + def __init__( + self, step_wrapper: _StepWrapper, log_queue: "Queue[Any]" + ) -> None: + self._step_wrapper = step_wrapper + self._log_queue = log_queue + + def run(self) -> str: + setup_logging(log_queue=self._log_queue) + return self._step_wrapper.run() + + resources: Dict[str, Any] = { + "name": f"distilabel-{self.name}-{step.name}-{replica}" + } + + if step.resources.cpus is not None: + resources["num_cpus"] = step.resources.cpus + + if step.resources.gpus is not None: + resources["num_gpus"] = step.resources.gpus + + if step.resources.memory is not None: + resources["memory"] = step.resources.memory + + if step.resources.resources is not None: + resources["resources"] = step.resources.resources + + _StepWrapperRay = _StepWrapperRay.options(**resources) # type: ignore + + self._logger.debug( + f"Creating Ray actor for '{step.name}' (replica ID: {replica}) with resources:" + f" {resources}" + ) + step_wrapper = _StepWrapperRay.remote( + step_wrapper=_StepWrapper( + step=step, # type: ignore + replica=replica, + input_queue=input_queue, + output_queue=self._output_queue, + load_queue=self._load_queue, + dry_run=self._dry_run, + ), + log_queue=self._log_queue, + ) + + self._logger.debug( + f"Executing remote `run` method of Ray actor for '{step.name}' (replica ID:" + f" {replica})..." + ) + step_wrapper.run.remote() + + def _teardown(self) -> None: + """Clean/release/stop resources reserved to run the pipeline.""" + if self._write_buffer: + self._write_buffer.close() + + if self._batch_manager: + self._batch_manager = None + + self._stop_load_queue_loop() + self._load_steps_thread.join() + + def _set_steps_not_loaded_exception(self) -> None: + pass + + def _stop(self) -> None: + """Stops the pipeline execution. It will first send `None` to the input queues + of all the steps and then wait until the output queue is empty i.e. all the steps + finished processing the batches that were sent before the stop flag. Then it will + send `None` to the output queue to notify the pipeline to stop.""" + with self._stop_called_lock: + if self._stop_called: + self._stop_calls += 1 + if self._stop_calls == 1: + self._logger.warning( + "🛑 Press again to force the pipeline to stop." + ) + elif self._stop_calls > 1: + self._logger.warning("🛑 Forcing pipeline interruption.") + + stop_logging() + + sys.exit(1) + + return + self._stop_called = True + + self._logger.debug( + f"Steps loaded before calling `stop`: {self._steps_load_status}" + ) + self._logger.info( + "🛑 Stopping pipeline. Waiting for steps to finish processing batches..." + ) + + self._stop_load_queue_loop() + self._stop_output_queue_loop() + + def dump(self, **kwargs: Any) -> Dict[str, Any]: + """Dumps the pipeline information. Override to hardcode the type info to `Pipeline`, + as we don't want to create a `RayPipeline` directly but create it using `Pipeline.ray` + method. + + Returns: + The pipeline dump. 
+        """
+        from distilabel.pipeline import Pipeline
+
+        dict_ = super().dump()
+        dict_["pipeline"][TYPE_INFO_KEY] = {
+            "module": Pipeline.__module__,
+            "name": Pipeline.__name__,
+        }
+        return dict_
diff --git a/src/distilabel/pipeline/step_wrapper.py b/src/distilabel/pipeline/step_wrapper.py
new file mode 100644
index 000000000..29c2e3e11
--- /dev/null
+++ b/src/distilabel/pipeline/step_wrapper.py
@@ -0,0 +1,315 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import traceback
+from queue import Queue
+from typing import Any, Dict, List, Optional, Union, cast
+
+from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.pipeline.batch import _Batch
+from distilabel.pipeline.constants import LAST_BATCH_SENT_FLAG
+from distilabel.pipeline.typing import StepLoadStatus
+from distilabel.steps.base import GeneratorStep, Step, _Step
+from distilabel.steps.tasks.base import Task
+
+
+class _StepWrapper:
+    """Wrapper to run the `Step`.
+
+    Attributes:
+        step: The step to run.
+        replica: The replica ID assigned.
+        input_queue: The queue to receive the input data.
+        output_queue: The queue to send the output data.
+        load_queue: The queue used to notify the main process that the step has been loaded,
+            has been unloaded or has failed to load.
+    """
+
+    def __init__(
+        self,
+        step: Union["Step", "GeneratorStep"],
+        replica: int,
+        input_queue: "Queue[_Batch]",
+        output_queue: "Queue[_Batch]",
+        load_queue: "Queue[Union[StepLoadStatus, None]]",
+        dry_run: bool = False,
+    ) -> None:
+        """Initializes the `_StepWrapper`.
+
+        Args:
+            step: The step to run.
+            replica: The replica ID assigned.
+            input_queue: The queue to receive the input data.
+            output_queue: The queue to send the output data.
+            load_queue: The queue used to notify the main process that the step has been
+                loaded, has been unloaded or has failed to load.
+            dry_run: Flag to ensure we are forcing to run the last batch.
+        """
+        self.step = step
+        self.replica = replica
+        self.input_queue = input_queue
+        self.output_queue = output_queue
+        self.load_queue = load_queue
+        self._dry_run = dry_run
+
+        if (
+            isinstance(self.step, Task)
+            and hasattr(self.step, "llm")
+            and isinstance(self.step.llm, CudaDevicePlacementMixin)
+        ):
+            self.step.llm._llm_identifier = self.step.name
+
+    def run(self) -> str:
+        """The target function executed by the process. This function will also handle
+        the step lifecycle, executing first the `load` function of the `Step` and then
+        waiting to receive a batch from the `input_queue` that will be handled by the
+        `process` method of the `Step`.
+
+        Returns:
+            The name of the step that was executed.
+ """ + + try: + self.step.load() + self.step._logger.debug(f"Step '{self.step.name}' loaded!") + except Exception as e: + self.step.unload() + self._notify_load_failed() + raise _StepWrapperException.create_load_error( + message=f"Step load failed: {e}", + step=self.step, + subprocess_exception=e, + ) from e + + self._notify_load() + + if self.step.is_generator: + self._generator_step_process_loop() + else: + self._non_generator_process_loop() + + # Just in case `None` sentinel was sent + try: + self.input_queue.get(block=False) + except Exception: + pass + + self.step.unload() + + self._notify_unload() + + self.step._logger.info( + f"🏁 Finished running step '{self.step.name}' (replica ID: {self.replica})" + ) + + return self.step.name # type: ignore + + def _notify_load(self) -> None: + """Notifies that the step has finished executing its `load` function successfully.""" + self.load_queue.put({"name": self.step.name, "status": "loaded"}) # type: ignore + + def _notify_unload(self) -> None: + """Notifies that the step has been unloaded.""" + self.load_queue.put({"name": self.step.name, "status": "unloaded"}) # type: ignore + + def _notify_load_failed(self) -> None: + """Notifies that the step failed to load.""" + self.load_queue.put({"name": self.step.name, "status": "load_failed"}) # type: ignore + + def _generator_step_process_loop(self) -> None: + """Runs the process loop for a generator step. It will call the `process` method + of the step and send the output data to the `output_queue` and block until the next + batch request is received (i.e. receiving an empty batch from the `input_queue`). + + If the `last_batch` attribute of the batch is `True`, the loop will stop and the + process will finish. + + Raises: + _StepWrapperException: If an error occurs during the execution of the + `process` method. + """ + step = cast("GeneratorStep", self.step) + try: + if (batch := self.input_queue.get()) is None: + self.step._logger.info( + f"🛑 Stopping yielding batches from step '{self.step.name}'" + ) + return + + offset = batch.seq_no * step.batch_size # type: ignore + + self.step._logger.info( + f"🧬 Starting yielding batches from generator step '{self.step.name}'." + f" Offset: {offset}" + ) + + for data, last_batch in step.process_applying_mappings(offset=offset): + batch.set_data([data]) + batch.last_batch = self._dry_run or last_batch + self._send_batch(batch) + + if batch.last_batch: + return + + self.step._logger.debug( + f"Step '{self.step.name}' waiting for next batch request..." + ) + if (batch := self.input_queue.get()) is None: + self.step._logger.info( + f"🛑 Stopping yielding batches from step '{self.step.name}'" + ) + return + except Exception as e: + raise _StepWrapperException(str(e), self.step, 2, e) from e + + def _non_generator_process_loop(self) -> None: + """Runs the process loop for a non-generator step. It will call the `process` + method of the step and send the output data to the `output_queue` and block until + the next batch is received from the `input_queue`. If the `last_batch` attribute + of the batch is `True`, the loop will stop and the process will finish. + + If an error occurs during the execution of the `process` method and the step is + global, the process will raise a `_StepWrapperException`. If the step is not + global, the process will log the error and send an empty batch to the `output_queue`. + + Raises: + _StepWrapperException: If an error occurs during the execution of the + `process` method and the step is global. 
+ """ + step = cast("Step", self.step) + while True: + if (batch := self.input_queue.get()) is None: + self.step._logger.info( + f"🛑 Stopping processing batches from step '{self.step.name}'" + ) + break + + if batch == LAST_BATCH_SENT_FLAG: + self.step._logger.debug("Received `LAST_BATCH_SENT_FLAG`. Stopping...") + break + + self.step._logger.info( + f"📦 Processing batch {batch.seq_no} in '{batch.step_name}' (replica ID: {self.replica})" + ) + + if batch.data_path is not None: + self.step._logger.debug(f"Reading batch data from '{batch.data_path}'") + batch.read_batch_data_from_fs() + + result = [] + try: + if self.step.has_multiple_inputs: + result = next(step.process_applying_mappings(*batch.data)) + else: + result = next(step.process_applying_mappings(batch.data[0])) + except Exception as e: + if self.step.is_global: + raise _StepWrapperException(str(e), self.step, 2, e) from e + + # Impute step outputs columns with `None` + result = self._impute_step_outputs(batch) + + # if the step is not global then we can skip the batch which means sending + # an empty batch to the output queue + self.step._logger.warning( + f"⚠️ Processing batch {batch.seq_no} with step '{self.step.name}' failed." + " Sending empty batch filled with `None`s..." + ) + self.step._logger.warning( + f"Subprocess traceback:\n\n{traceback.format_exc()}" + ) + finally: + batch.set_data([result]) + self._send_batch(batch) + + if batch.last_batch: + break + + def _impute_step_outputs(self, batch: "_Batch") -> List[Dict[str, Any]]: + """Imputes the step outputs columns with `None` in the batch data. + + Args: + batch: The batch to impute. + """ + result = [] + for row in batch.data[0]: + data = row.copy() + for output in self.step.outputs: + data[output] = None + result.append(data) + return result + + def _send_batch(self, batch: _Batch) -> None: + """Sends a batch to the `output_queue`.""" + if batch.data_path is not None: + self.step._logger.debug(f"Writing batch data to '{batch.data_path}'") + batch.write_batch_data_to_fs() + + self.step._logger.info( + f"📨 Step '{batch.step_name}' sending batch {batch.seq_no} to output queue" + ) + self.output_queue.put(batch) + + +class _StepWrapperException(Exception): + """Exception to be raised when an error occurs in the `_StepWrapper` class. + + Attributes: + message: The error message. + step: The `Step` that raised the error. + code: The error code. + subprocess_exception: The exception raised by the subprocess. Defaults to `None`. + """ + + def __init__( + self, + message: str, + step: "_Step", + code: int, + subprocess_exception: Optional[Exception] = None, + ) -> None: + self.message = message + self.step = step + self.code = code + self.subprocess_exception = subprocess_exception + self.formatted_traceback = "".join( + traceback.format_exception(subprocess_exception) + ) + + @classmethod + def create_load_error( + cls, + message: str, + step: "_Step", + subprocess_exception: Optional[Exception] = None, + ) -> "_StepWrapperException": + """Creates a `_StepWrapperException` for a load error. + + Args: + message: The error message. + step: The `Step` that raised the error. + subprocess_exception: The exception raised by the subprocess. Defaults to `None`. + + Returns: + The `_StepWrapperException` instance. + """ + return cls(message, step, 1, subprocess_exception) + + @property + def is_load_error(self) -> bool: + """Whether the error is a load error. + + Returns: + `True` if the error is a load error, `False` otherwise. 
+ """ + return self.code == 1 diff --git a/src/distilabel/steps/base.py b/src/distilabel/steps/base.py index 59f3bda80..940b05d81 100644 --- a/src/distilabel/steps/base.py +++ b/src/distilabel/steps/base.py @@ -97,6 +97,9 @@ class StepResources(RuntimeParametersMixin, BaseModel): replicas: The number of replicas for the step. cpus: The number of CPUs assigned to each step replica. gpus: The number of GPUs assigned to each step replica. + memory: The memory in bytes required for each step replica. + resources: A dictionary containing the number of custom resources required for + each step replica. """ replicas: RuntimeParameter[PositiveInt] = Field( @@ -108,6 +111,14 @@ class StepResources(RuntimeParametersMixin, BaseModel): gpus: Optional[RuntimeParameter[PositiveInt]] = Field( default=None, description="The number of GPUs assigned to each step replica." ) + memory: Optional[RuntimeParameter[PositiveInt]] = Field( + default=None, description="The memory in bytes required for each step replica." + ) + resources: Optional[RuntimeParameter[Dict[str, int]]] = Field( + default=None, + description="A dictionary containing names of custom resources and the" + " number of those resources required for each step replica.", + ) class _Step(RuntimeParametersMixin, RequirementsMixin, BaseModel, _Serializable, ABC): diff --git a/src/distilabel/utils/logging.py b/src/distilabel/utils/logging.py index fd616f000..c69ebcda1 100644 --- a/src/distilabel/utils/logging.py +++ b/src/distilabel/utils/logging.py @@ -96,5 +96,8 @@ def stop_logging() -> None: global queue_listener if queue_listener is not None: queue_listener.stop() - queue_listener.queue.close() + if hasattr(queue_listener.queue, "close"): + queue_listener.queue.close() # type: ignore + if hasattr(queue_listener.queue, "shutdown"): + queue_listener.queue.shutdown() # type: ignore queue_listener = None diff --git a/src/distilabel/utils/ray.py b/src/distilabel/utils/ray.py new file mode 100644 index 000000000..4325e264b --- /dev/null +++ b/src/distilabel/utils/ray.py @@ -0,0 +1,28 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + + +def script_executed_in_ray_cluster() -> bool: + """Checks if running in a Ray cluster. The checking is based on the presence of + typical Ray environment variables that are set in each node of the cluster. + + Returns: + `True` if running on a Ray cluster, `False` otherwise. + """ + return all( + env in os.environ + for env in ["RAY_NODE_TYPE_NAME", "RAY_CLUSTER_NAME", "RAY_ADDRESS"] + ) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..20ce00bda --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/integration/test_pipe_simple.py b/tests/integration/test_pipe_simple.py index 31a624b15..eee334677 100644 --- a/tests/integration/test_pipe_simple.py +++ b/tests/integration/test_pipe_simple.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Generator, List +from typing import TYPE_CHECKING, Dict, List from distilabel.distiset import Distiset from distilabel.mixins.runtime_parameters import RuntimeParameter @@ -20,6 +20,9 @@ from distilabel.steps.base import Step, StepInput from distilabel.steps.generators.data import LoadDataFromDicts +if TYPE_CHECKING: + from distilabel.steps.typing import StepOutput + DATA = [ {"prompt": "Tell me a joke"}, {"prompt": "Write a short haiku"}, @@ -105,7 +108,7 @@ class RenameColumns(Step): - rename_mappings: RuntimeParameter[Dict[str, str]] + rename_mappings: RuntimeParameter[Dict[str, str]] = None @property def inputs(self) -> List[str]: @@ -115,7 +118,7 @@ def inputs(self) -> List[str]: def outputs(self) -> List[str]: return list(self.rename_mappings.values()) # type: ignore - def process(self, inputs: StepInput) -> Generator[List[Dict[str, Any]], None, None]: + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore outputs = [] for input in inputs: outputs.append( @@ -129,7 +132,7 @@ class GenerateResponse(Step): def inputs(self) -> List[str]: return ["instruction"] - def process(self, inputs: StepInput) -> Generator[List[Dict[str, Any]], None, None]: + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore import time time.sleep(1) diff --git a/tests/integration/test_ray_pipeline.py b/tests/integration/test_ray_pipeline.py new file mode 100644 index 000000000..9f63212a7 --- /dev/null +++ b/tests/integration/test_ray_pipeline.py @@ -0,0 +1,183 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +from typing import TYPE_CHECKING, Dict, List + +import pytest +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.pipeline.ray import RayPipeline +from distilabel.steps.base import Step, StepInput +from distilabel.steps.generators.data import LoadDataFromDicts + +if TYPE_CHECKING: + from distilabel.steps.typing import StepOutput + +DATA = [ + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": 
"Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, + {"prompt": "Tell me a joke"}, + {"prompt": "Write a short haiku"}, + {"prompt": "Translate 'My name is Alvaro' to Spanish"}, + {"prompt": "What's the capital of Spain?"}, +] + + +class RenameColumns(Step): + rename_mappings: RuntimeParameter[Dict[str, str]] = None + + @property + def inputs(self) -> List[str]: + return [] + + @property + def outputs(self) -> List[str]: + return list(self.rename_mappings.values()) # type: ignore + + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore + outputs = [] + for input in inputs: + outputs.append( + {self.rename_mappings.get(k, k): v for k, v in input.items()} # type: ignore + ) + yield outputs + + +class GenerateResponse(Step): + @property + def inputs(self) -> List[str]: + return ["instruction"] + + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore + import time + + time.sleep(1) + + for input in inputs: + input["response"] = "I don't know" + + yield inputs + + @property + def outputs(self) -> List[str]: + return ["response"] + + +@pytest.mark.skipif( + sys.version_info >= (3, 12), reason="`ray` is not compatible with `python>=3.12`" +) +def test_run_pipeline() -> None: + import ray + from ray.cluster_utils import Cluster + + # TODO: if we add more tests, this should be a fixture + cluster = Cluster(initialize_head=True, head_node_args={"num_cpus": 10}) + ray.init(address=cluster.address) + + with RayPipeline( + name="unit-test-pipeline", ray_init_kwargs={"ignore_reinit_error": True} + ) as pipeline: + load_dataset = LoadDataFromDicts(name="load_dataset", data=DATA, batch_size=8) + rename_columns = RenameColumns(name="rename_columns", input_batch_size=12) + generate_response = GenerateResponse( + name="generate_response", input_batch_size=16 + ) + + load_dataset >> rename_columns >> generate_response + + distiset = pipeline.run( + parameters={ + "rename_columns": { + "rename_mappings": { + "prompt": "instruction", + }, + }, + } + ) + + assert len(distiset["default"]["train"]) == 80 diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index 90231eec1..26773f9a4 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -130,12 +130,6 @@ def test_setup_write_buffer(self) -> None: pipeline._setup_write_buffer() assert isinstance(pipeline._write_buffer, _WriteBuffer) - def test_set_logging_parameters(self) -> None: - pipeline = DummyPipeline(name="unit-test-pipeline") - pipeline._set_logging_parameters({"unit-test": "yes"}) - - assert pipeline._logging_parameters == {"unit-test": "yes"} - def test_setup_fsspec(self) -> None: pipeline = DummyPipeline(name="unit-test-pipeline") @@ -850,7 +844,7 @@ class DummyStep1(Step): default=None, description="runtime_param2 description" ) - def process(self, inputs: StepInput) -> None: + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore pass class DummyStep2(Step): @@ -861,7 +855,7 @@ class DummyStep2(Step): default=None, description="runtime_param4 description" ) - def process(self, inputs: StepInput) -> None: + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore pass with DummyPipeline(name="unit-test-pipeline") as pipeline: @@ -888,6 +882,16 @@ def process(self, inputs: StepInput) -> None: "name": "gpus", "optional": True, }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + 
"description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, ], }, { @@ -926,6 +930,16 @@ def process(self, inputs: StepInput) -> None: "name": "gpus", "optional": True, }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, ], }, { @@ -1223,7 +1237,7 @@ def test_pipeline_to_from_file_format( pipe_from_file = loader(filename) assert isinstance(pipe_from_file, DummyPipeline) - def test_base_pipeline_signature(self): + def test_base_pipeline_signature(self) -> None: pipeline = DummyPipeline(name="unit-test-pipeline") # Doesn't matter if it's exactly this or not, the test should fail if we change the # way this is created. @@ -1253,7 +1267,7 @@ def test_base_pipeline_signature(self): ) signature = pipeline._create_signature() - assert signature == "1ef5193c8686de48728cb9e5e9b88bca62bc0957" + assert signature == "f291da215cd42085c538e4897e4355f614932547" def test_binary_rshift_operator(self) -> None: # Tests the steps can be connected using the >> operator. diff --git a/tests/unit/pipeline/test_local.py b/tests/unit/pipeline/test_local.py index 3b7f9cd26..7e9d993ac 100644 --- a/tests/unit/pipeline/test_local.py +++ b/tests/unit/pipeline/test_local.py @@ -24,9 +24,9 @@ pass -class TestLocalPipeline: - @mock.patch("distilabel.pipeline.local._ProcessWrapper") - def test_create_processes(self, process_wrapper_mock: mock.MagicMock) -> None: +class TestPipeline: + @mock.patch("distilabel.pipeline.local._StepWrapper") + def test_run_steps(self, step_wrapper_mock: mock.MagicMock) -> None: with Pipeline(name="unit-test-pipeline") as pipeline: dummy_generator = DummyGeneratorStep(name="dummy_generator_step") dummy_step_1 = DummyStep1( @@ -46,7 +46,7 @@ def test_create_processes(self, process_wrapper_mock: mock.MagicMock) -> None: assert pipeline._manager.Queue.call_count == 3 - process_wrapper_mock.assert_has_calls( + step_wrapper_mock.assert_has_calls( [ mock.call( step=dummy_generator, @@ -86,16 +86,67 @@ def test_create_processes(self, process_wrapper_mock: mock.MagicMock) -> None: pipeline._pool.apply_async.assert_has_calls( [ mock.call( - process_wrapper_mock.return_value.run, + step_wrapper_mock.return_value.run, error_callback=pipeline._error_callback, ), mock.call( - process_wrapper_mock.return_value.run, + step_wrapper_mock.return_value.run, error_callback=pipeline._error_callback, ), mock.call( - process_wrapper_mock.return_value.run, + step_wrapper_mock.return_value.run, error_callback=pipeline._error_callback, ), ] ) + + def test_ray(self) -> None: + with Pipeline( + name="dummy", + description="dummy", + cache_dir="/tmp", + enable_metadata=True, + requirements=["dummy"], + ) as pipeline: + generator = DummyGeneratorStep() + dummy = DummyStep1() + + generator >> dummy + + ray_pipeline = pipeline.ray() + + assert ray_pipeline.name == pipeline.name + assert ray_pipeline.description == pipeline.description + assert ray_pipeline._cache_dir == pipeline._cache_dir + assert ray_pipeline._enable_metadata == pipeline._enable_metadata + assert ray_pipeline.requirements == pipeline.requirements + assert ray_pipeline.dag == pipeline.dag + + def test_run_detected_ray(self) -> None: + with Pipeline( + name="dummy", + 
description="dummy", + cache_dir="/tmp", + enable_metadata=True, + requirements=["dummy"], + ) as pipeline: + generator = DummyGeneratorStep() + dummy = DummyStep1() + + generator >> dummy + + run_pipeline_mock = mock.MagicMock() + + with ( + mock.patch( + "distilabel.pipeline.local.script_executed_in_ray_cluster", + return_value=True, + ), + mock.patch( + "distilabel.pipeline.local.Pipeline.ray", return_value=run_pipeline_mock + ) as ray_mock, + ): + pipeline.run() + + ray_mock.assert_called_once() + run_pipeline_mock.run.assert_called_once() diff --git a/tests/unit/pipeline/test_ray.py b/tests/unit/pipeline/test_ray.py new file mode 100644 index 000000000..b66a7dba5 --- /dev/null +++ b/tests/unit/pipeline/test_ray.py @@ -0,0 +1,27 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distilabel.pipeline.ray import RayPipeline +from distilabel.utils.serialization import TYPE_INFO_KEY + + +class TestRayPipeline: + def test_dump(self) -> None: + pipeline = RayPipeline(name="unit-test") + dump = pipeline.dump() + + assert dump["pipeline"][TYPE_INFO_KEY] == { + "module": "distilabel.pipeline.local", + "name": "Pipeline", + } diff --git a/tests/unit/steps/argilla/test_base.py b/tests/unit/steps/argilla/test_base.py index 2ee5d0ad1..def9386b1 100644 --- a/tests/unit/steps/argilla/test_base.py +++ b/tests/unit/steps/argilla/test_base.py @@ -120,7 +120,9 @@ def test_serialization(self) -> None: "resources": { "cpus": None, "gpus": None, + "memory": None, "replicas": 1, + "resources": None, }, "input_batch_size": 50, "dataset_name": "argilla", @@ -145,6 +147,16 @@ def test_serialization(self) -> None: "name": "gpus", "optional": True, }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, ], }, { diff --git a/tests/unit/steps/argilla/test_preference.py b/tests/unit/steps/argilla/test_preference.py index ea3d357d2..1d0137978 100644 --- a/tests/unit/steps/argilla/test_preference.py +++ b/tests/unit/steps/argilla/test_preference.py @@ -101,7 +101,9 @@ def test_serialization(self) -> None: "resources": { "cpus": None, "gpus": None, + "memory": None, "replicas": 1, + "resources": None, }, "input_batch_size": 50, "num_generations": 2, @@ -127,6 +129,16 @@ def test_serialization(self) -> None: "name": "gpus", "optional": True, }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, ], }, { diff --git a/tests/unit/steps/argilla/test_text_generation.py b/tests/unit/steps/argilla/test_text_generation.py index 79eb0c362..f9224367b 100644 --- 
a/tests/unit/steps/argilla/test_text_generation.py +++ b/tests/unit/steps/argilla/test_text_generation.py @@ -75,7 +75,9 @@ def test_serialization(self) -> None: "resources": { "cpus": None, "gpus": None, + "memory": None, "replicas": 1, + "resources": None, }, "input_batch_size": 50, "dataset_name": "argilla", @@ -100,6 +102,16 @@ def test_serialization(self) -> None: "name": "gpus", "optional": True, }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, ], }, { diff --git a/tests/unit/steps/tasks/evol_instruct/test_base.py b/tests/unit/steps/tasks/evol_instruct/test_base.py index c8df4871e..59ca4fae9 100644 --- a/tests/unit/steps/tasks/evol_instruct/test_base.py +++ b/tests/unit/steps/tasks/evol_instruct/test_base.py @@ -127,7 +127,9 @@ def test_serialization(self, dummy_llm: LLM) -> None: "resources": { "cpus": None, "gpus": None, + "memory": None, "replicas": 1, + "resources": None, }, "input_batch_size": task.input_batch_size, "llm": { @@ -170,6 +172,16 @@ def test_serialization(self, dummy_llm: LLM) -> None: "name": "gpus", "optional": True, }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, ], }, { diff --git a/tests/unit/steps/tasks/evol_instruct/test_generator.py b/tests/unit/steps/tasks/evol_instruct/test_generator.py index 6b8b5302b..9f9612148 100644 --- a/tests/unit/steps/tasks/evol_instruct/test_generator.py +++ b/tests/unit/steps/tasks/evol_instruct/test_generator.py @@ -128,7 +128,9 @@ def test_serialization(self, dummy_llm: LLM) -> None: "resources": { "cpus": None, "gpus": None, + "memory": None, "replicas": 1, + "resources": None, }, "batch_size": task.batch_size, "num_instructions": task.num_instructions, @@ -165,6 +167,16 @@ def test_serialization(self, dummy_llm: LLM) -> None: "name": "gpus", "optional": True, }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, ], }, { diff --git a/tests/unit/steps/tasks/evol_quality/test_base.py b/tests/unit/steps/tasks/evol_quality/test_base.py index b11a37f43..0e41ba29f 100644 --- a/tests/unit/steps/tasks/evol_quality/test_base.py +++ b/tests/unit/steps/tasks/evol_quality/test_base.py @@ -98,7 +98,9 @@ def test_serialization(self, dummy_llm: LLM) -> None: "resources": { "cpus": None, "gpus": None, + "memory": None, "replicas": 1, + "resources": None, }, "input_batch_size": task.input_batch_size, "llm": { @@ -134,6 +136,16 @@ def test_serialization(self, dummy_llm: LLM) -> None: "name": "gpus", "optional": True, }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, ], }, { diff --git 
a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py index 254089f3d..44cfcb6df 100644 --- a/tests/unit/steps/tasks/test_base.py +++ b/tests/unit/steps/tasks/test_base.py @@ -202,7 +202,9 @@ def test_serialization(self) -> None: "resources": { "cpus": None, "gpus": None, + "memory": None, "replicas": 1, + "resources": None, }, "input_batch_size": 50, "llm": { @@ -233,6 +235,16 @@ def test_serialization(self) -> None: "name": "gpus", "optional": True, }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, ], }, { diff --git a/tests/unit/steps/tasks/test_pair_rm.py b/tests/unit/steps/tasks/test_pair_rm.py index 2ded97dac..eeeadffae 100644 --- a/tests/unit/steps/tasks/test_pair_rm.py +++ b/tests/unit/steps/tasks/test_pair_rm.py @@ -64,7 +64,9 @@ def test_serialization(self, _: MagicMock) -> None: "resources": { "cpus": None, "gpus": None, + "memory": None, "replicas": 1, + "resources": None, }, "input_batch_size": ranker.input_batch_size, "model": ranker.model, @@ -88,6 +90,16 @@ def test_serialization(self, _: MagicMock) -> None: "name": "gpus", "optional": True, }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, ], }, { diff --git a/tests/unit/steps/test_base.py b/tests/unit/steps/test_base.py index 8c7636886..d66cb927e 100644 --- a/tests/unit/steps/test_base.py +++ b/tests/unit/steps/test_base.py @@ -124,6 +124,8 @@ def process(self, *inputs: StepInput) -> StepOutput: "cpus": True, "gpus": True, "replicas": True, + "memory": True, + "resources": True, }, "runtime_param1": False, "runtime_param2": True, @@ -298,7 +300,9 @@ def test_step_dump(self) -> None: "resources": { "cpus": None, "gpus": None, + "memory": None, "replicas": 1, + "resources": None, }, "runtime_parameters_info": [ { @@ -319,6 +323,16 @@ def test_step_dump(self) -> None: "name": "gpus", "optional": True, }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, ], }, { diff --git a/tests/unit/steps/test_decorator.py b/tests/unit/steps/test_decorator.py index eb12b75b4..4b97cce8e 100644 --- a/tests/unit/steps/test_decorator.py +++ b/tests/unit/steps/test_decorator.py @@ -54,6 +54,8 @@ def UnitTestStep( "cpus": True, "gpus": True, "replicas": True, + "memory": True, + "resources": True, }, "runtime_param1": False, "runtime_param2": True, diff --git a/tests/unit/utils/test_ray.py b/tests/unit/utils/test_ray.py new file mode 100644 index 000000000..5fc918354 --- /dev/null +++ b/tests/unit/utils/test_ray.py @@ -0,0 +1,32 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest import mock + +from distilabel.utils.ray import script_executed_in_ray_cluster + + +def test_script_executed_on_ray_cluster() -> None: + assert not script_executed_in_ray_cluster() + + with mock.patch.dict( + os.environ, + { + "RAY_NODE_TYPE_NAME": "headgroup", + "RAY_CLUSTER_NAME": "disticluster", + "RAY_ADDRESS": "127.0.0.1:6379", + }, + ): + assert script_executed_in_ray_cluster()
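For reference, a hedged sketch (not part of the patch) of how the new `StepResources` fields introduced in `src/distilabel/steps/base.py` above could be filled in. It assumes the pydantic model can be instantiated directly; the values are illustrative and `"custom_accelerator"` is a hypothetical custom Ray resource name.

```python
from distilabel.steps.base import StepResources

# Hedged sketch: per-replica requirements using the new `memory` (bytes) and
# custom `resources` fields added by this patch. All values are illustrative.
resources = StepResources(
    replicas=2,                           # two parallel replicas of the step
    cpus=4,                               # CPUs reserved per replica
    gpus=1,                               # GPUs reserved per replica
    memory=16 * 1024**3,                  # 16 GiB per replica, in bytes
    resources={"custom_accelerator": 1},  # hypothetical custom Ray resource
)

print(resources.memory)  # 17179869184
```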