Add load stages (#760)
* Create N replicas per `Step`

* Update `_BatchManager` to handle batch sorting uncertainty

* Add multiple replicas test

* Fix unit tests

* Fix `next_expected_seq_no` needed to be updated if
`routing_batch_function`

* Update `set_next_expected_batch_seq_no` only if no `data`

* Fix `next_expected_seq_no` with `routing_batch_function`

* Remove prints

* Add `StepResource` import

* Add missing return type hint

* Add `StepResources` docs

* Add `get_steps_load_stages` method

* Update to load steps in stages

* Add `_teardown` method

* Add load stages

* Add printing info about stages

* Refactor load stages to avoid race conditions

* Add load stages integration test

* Fix unit tests

* Add unit tests for new methods

* Move send last batch message

* Refactor to make it work with routing batch function

* Add integration test for load stages & routing batch function

* Update docs to tell about resources as runtime parameters

* Add missing doc pages

* Update to load stages from cache

* Fix bugs requesting initial batches

* Add integration tests for recovering states from cache

* Remove atexit

* Fix docstring typos

Co-authored-by: Agus <[email protected]>

gabrielmbmb and plaguss authored Jul 3, 2024
1 parent 91bc0fa commit 87b8f85
Showing 16 changed files with 767 additions and 170 deletions.
1 change: 1 addition & 0 deletions docs/api/mixins/requirements.md
@@ -0,0 +1 @@
::: distilabel.mixins.requirements.RequirementsMixin
1 change: 1 addition & 0 deletions docs/api/mixins/runtime_parameters.md
@@ -0,0 +1 @@
::: distilabel.mixins.runtime_parameters.RuntimeParametersMixin
@@ -26,5 +26,18 @@ with Pipeline(name="resources") as pipeline:
)
```

In the example above, we're creating a `PrometheusEval` task (remember that `Task`s are `Step`s) that will use `vLLM` to serve `prometheus-eval/prometheus-7b-v2.0` model. This task is resource intensive as it requires an LLM, which in turn requires a GPU to run fast. With that in mind, we have specified the `resources` required for the task using the [`StepResources`][distilabel.steps.base.StepResources] class, and we have defined that we need `1` GPU and `1` CPU per replica of the task. In addition, we have defined that we need `2` replicas i.e. we will run two instances of the task so the computation for the whole dataset runs faster. When running the pipeline, `distilabel` will create the tasks in nodes that have available the specified resources.
In the example above, we're creating a `PrometheusEval` task (remember that `Task`s are `Step`s) that will use `vLLM` to serve `prometheus-eval/prometheus-7b-v2.0` model. This task is resource intensive as it requires an LLM, which in turn requires a GPU to run fast. With that in mind, we have specified the `resources` required for the task using the [`StepResources`][distilabel.steps.base.StepResources] class, and we have defined that we need `1` GPU and `1` CPU per replica of the task. In addition, we have defined that we need `2` replicas i.e. we will run two instances of the task so the computation for the whole dataset runs faster. In addition, `StepResources` uses the [RuntimeParametersMixin][distilabel.mixins.runtime_parameters.RuntimeParametersMixin], so we can also specify the resources for each step when running the pipeline:

```python
...

if __name__ == "__main__":
    pipeline.run(
        parameters={
            prometheus.name: {"resources": {"replicas": 2, "cpus": 1, "gpus": 1}}
        }
    )
```

And that's it! When running the pipeline, `distilabel` will create the tasks in nodes that have available the specified resources.

5 changes: 4 additions & 1 deletion mkdocs.yml
@@ -33,7 +33,7 @@ theme:
logo: assets/logo.svg
favicon: assets/logo.svg
icon:
repo: fontawesome/brands/github-alt
repo: fontawesome/brands/github
features:
- navigation.instant
- navigation.sections
@@ -210,6 +210,9 @@ nav:
- Routing Batch Function: "api/pipeline/routing_batch_function.md"
- Typing: "api/pipeline/typing.md"
- Utils: "api/pipeline/utils.md"
- Mixins:
- RuntimeParametersMixin: "api/mixins/runtime_parameters.md"
- RequirementsMixin: "api/mixins/requirements.md"
- Distiset: "api/distiset.md"
- CLI: "api/cli.md"
- Community:
43 changes: 42 additions & 1 deletion src/distilabel/pipeline/_dag.py
@@ -23,6 +23,7 @@
    Iterable,
    List,
    Set,
    Tuple,
    Type,
    Union,
)
@@ -258,6 +259,46 @@ def get_total_replica_count(self) -> int:
        """
        return sum([self.get_step_replica_count(step_name) for step_name in self.G])

    def get_steps_load_stages(self) -> Tuple[List[List[str]], List[List[str]]]:
        """Gets the stages in which the `Step`s of the `Pipeline` should be loaded. Stages
        are determined by `GlobalStep`s as they receive all the data at once, which means
        that a `GlobalStep` is not required to be loaded until all their previous steps
        have finished their execution, and the successors of the global step are not required
        to be loaded until the global has finished.

        Returns:
            A tuple with the first element containing a sorted list by stage containing
            lists with the names of the steps of the stage, and the second element a list
            sorted by stage containing lists with the names of the last steps of the stage.
        """

        def _get_stage_last_steps(stage_steps: List[str]) -> List[str]:
            subgraph = self.G.subgraph(stage_steps)
            return sorted(
                [node for node in subgraph.nodes() if subgraph.out_degree(node) == 0]
            )

        stages = []
        current_stage = []
        stages_last_steps = []

        for step_name in nx.topological_sort(self.G):
            step: "_Step" = self.get_step(step_name)[STEP_ATTR_NAME]
            if not step.is_global:
                current_stage.append(step_name)
            else:
                stages.append(current_stage)
                stages_last_steps.append(_get_stage_last_steps(current_stage))
                stages.append([step_name])
                stages_last_steps.append([step_name])
                current_stage = []

        if current_stage:
            stages.append(current_stage)
            stages_last_steps.append(_get_stage_last_steps(current_stage))

        return stages, stages_last_steps

    def validate(self) -> None:
        """Validates that the `Step`s included in the pipeline are correctly connected and
        have the correct inputs and outputs.
@@ -299,7 +340,7 @@ def validate(self) -> None:
# Validate routing batch function (if any)
predecessors = list(self.get_step_predecessors(step.name)) # type: ignore
self._validate_convergence_step(
step,
step, # type: ignore
predecessors,
steps_receiving_routed_batches, # type: ignore
)
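
For readers who want to try the staging idea from `get_steps_load_stages` outside of `distilabel`, here is a minimal standalone sketch. It is not the code from this commit: it swaps `networkx` for the standard library's `graphlib`, and the function name `load_stages`, the `successors` mapping, the `global_steps` set, and the step names are all hypothetical and used only for illustration. As in the method above, a global step closes the current stage and forms a stage of its own.

```python
from graphlib import TopologicalSorter
from typing import Dict, List, Set, Tuple


def load_stages(
    successors: Dict[str, List[str]], global_steps: Set[str]
) -> Tuple[List[List[str]], List[List[str]]]:
    """Toy staging logic: split a DAG into load stages at global steps."""
    # graphlib expects a node -> predecessors mapping, so invert the edges.
    predecessors: Dict[str, Set[str]] = {name: set() for name in successors}
    for name, targets in successors.items():
        for target in targets:
            predecessors.setdefault(target, set()).add(name)

    def last_steps(stage: List[str]) -> List[str]:
        # A step is "last" in its stage if none of its successors is in that stage.
        members = set(stage)
        return sorted(
            s for s in stage if not members.intersection(successors.get(s, []))
        )

    stages: List[List[str]] = []
    stages_last_steps: List[List[str]] = []
    current: List[str] = []

    for name in TopologicalSorter(predecessors).static_order():
        if name not in global_steps:
            current.append(name)
            continue
        # A global step closes the current stage and is a stage by itself.
        stages.append(current)
        stages_last_steps.append(last_steps(current))
        stages.append([name])
        stages_last_steps.append([name])
        current = []

    if current:
        stages.append(current)
        stages_last_steps.append(last_steps(current))

    return stages, stages_last_steps


# Hypothetical linear pipeline: generator -> task_a -> global_dedup -> task_b
graph = {
    "generator": ["task_a"],
    "task_a": ["global_dedup"],
    "global_dedup": ["task_b"],
    "task_b": [],
}
stages, last = load_stages(graph, global_steps={"global_dedup"})
print(stages)  # [['generator', 'task_a'], ['global_dedup'], ['task_b']]
print(last)    # [['task_a'], ['global_dedup'], ['task_b']]
```

With this toy graph, the steps before the global step load together as the first stage, the global step loads once `task_a` (the last step of that stage) has finished, and `task_b` loads only after the global step is done.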